diff --git a/.gitignore b/.gitignore
index d0f4500..0e05091 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,3 +50,4 @@ marltoolbox/examples/Tutorial_*.py
marltoolbox/experiments/Tutorial_*.py
api_key_wandb
wandb/
+requirements.txt
diff --git a/marltoolbox/algos/amTFT/__init__.py b/marltoolbox/algos/amTFT/__init__.py
index a8d5f52..d3d651e 100644
--- a/marltoolbox/algos/amTFT/__init__.py
+++ b/marltoolbox/algos/amTFT/__init__.py
@@ -19,3 +19,21 @@
AmTFTRolloutsTorchPolicy,
)
from marltoolbox.algos.amTFT.train_helper import train_amtft
+
+
+__all__ = [
+ "train_amtft",
+ "AmTFTRolloutsTorchPolicy",
+ "amTFTQValuesTorchPolicy",
+ "Level1amTFTExploiterTorchPolicy",
+ "observation_fn",
+ "AmTFTCallbacks",
+ "WORKING_STATES",
+ "WORKING_STATES_IN_EVALUATION",
+ "AmTFTReferenceClass",
+ "DEFAULT_NESTED_POLICY_COOP",
+ "DEFAULT_NESTED_POLICY_SELFISH",
+ "PLOT_ASSEMBLAGE_TAGS",
+ "PLOT_KEYS",
+ "DEFAULT_CONFIG",
+]
diff --git a/marltoolbox/algos/amTFT/base.py b/marltoolbox/algos/amTFT/base.py
index 6c383a3..a45697f 100644
--- a/marltoolbox/algos/amTFT/base.py
+++ b/marltoolbox/algos/amTFT/base.py
@@ -4,20 +4,23 @@
from ray.rllib.utils import merge_dicts
from marltoolbox.algos import hierarchical, augmented_dqn
-from marltoolbox.utils import \
- postprocessing, miscellaneous
+from marltoolbox.utils import postprocessing, miscellaneous
logger = logging.getLogger(__name__)
APPROXIMATION_METHOD_Q_VALUE = "amTFT_use_Q_net"
APPROXIMATION_METHOD_ROLLOUTS = "amTFT_use_rollout"
-APPROXIMATION_METHODS = (APPROXIMATION_METHOD_Q_VALUE,
- APPROXIMATION_METHOD_ROLLOUTS)
-WORKING_STATES = ("train_coop",
- "train_selfish",
- "eval_amtft",
- "eval_naive_selfish",
- "eval_naive_coop")
+APPROXIMATION_METHODS = (
+ APPROXIMATION_METHOD_Q_VALUE,
+ APPROXIMATION_METHOD_ROLLOUTS,
+)
+WORKING_STATES = (
+ "train_coop",
+ "train_selfish",
+ "eval_amtft",
+ "eval_naive_selfish",
+ "eval_naive_coop",
+)
WORKING_STATES_IN_EVALUATION = WORKING_STATES[2:]
OWN_COOP_POLICY_IDX = 0
@@ -29,8 +32,9 @@
DEFAULT_NESTED_POLICY_COOP = DEFAULT_NESTED_POLICY_SELFISH.with_updates(
postprocess_fn=miscellaneous.merge_policy_postprocessing_fn(
postprocessing.welfares_postprocessing_fn(
- add_utilitarian_welfare=True, ),
- postprocess_nstep_and_prio
+ add_utilitarian_welfare=True,
+ ),
+ postprocess_nstep_and_prio,
)
)
@@ -44,31 +48,35 @@
"rollout_length": 40,
"n_rollout_replicas": 20,
"last_k": 1,
+ "punish_instead_of_selfish": False,
# TODO use log level of RLLib instead of mine
"verbose": 1,
"auto_load_checkpoint": True,
-
# To configure
"own_policy_id": None,
"opp_policy_id": None,
"callbacks": None,
# One from marltoolbox.utils.postprocessing.WELFARES
"welfare_key": postprocessing.WELFARE_UTILITARIAN,
-
- 'nested_policies': [
+ "nested_policies": [
# Here the trainer needs to be a DQNTrainer
# to provide the config for the 3 DQNTorchPolicy
- {"Policy_class": DEFAULT_NESTED_POLICY_COOP,
- "config_update": {}},
- {"Policy_class": DEFAULT_NESTED_POLICY_SELFISH,
- "config_update": {}},
- {"Policy_class": DEFAULT_NESTED_POLICY_COOP,
- "config_update": {}},
- {"Policy_class": DEFAULT_NESTED_POLICY_SELFISH,
- "config_update": {}},
+ {"Policy_class": DEFAULT_NESTED_POLICY_COOP, "config_update": {}},
+ {
+ "Policy_class": DEFAULT_NESTED_POLICY_SELFISH,
+ "config_update": {},
+ },
+ {"Policy_class": DEFAULT_NESTED_POLICY_COOP, "config_update": {}},
+ {
+ "Policy_class": DEFAULT_NESTED_POLICY_SELFISH,
+ "config_update": {},
+ },
],
- "optimizer": {"sgd_momentum": 0.0, },
- }
+ "optimizer": {
+ "sgd_momentum": 0.0,
+ },
+ "batch_mode": "complete_episodes",
+ },
)
PLOT_KEYS = [
@@ -76,7 +84,8 @@
"debit",
"debit_threshold",
"summed_debit",
- "summed_n_steps_to_punish"
+ "summed_n_steps_to_punish",
+ "reset_rnn_state",
]
PLOT_ASSEMBLAGE_TAGS = [
@@ -85,6 +94,7 @@
("debit_threshold",),
("summed_debit",),
("summed_n_steps_to_punish",),
+ ("reset_rnn_state",),
]
diff --git a/marltoolbox/algos/amTFT/base_policy.py b/marltoolbox/algos/amTFT/base_policy.py
index 33c63d0..9b8978e 100644
--- a/marltoolbox/algos/amTFT/base_policy.py
+++ b/marltoolbox/algos/amTFT/base_policy.py
@@ -1,14 +1,19 @@
import copy
import logging
-from typing import List, Union, Optional, Dict, Tuple, TYPE_CHECKING
+from typing import Dict, TYPE_CHECKING
import numpy as np
+import torch
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import override
-from ray.rllib.utils.typing import TensorType, PolicyID
+from ray.rllib.utils.threading import with_lock
+from ray.rllib.utils.torch_ops import (
+ convert_to_torch_tensor,
+)
+from ray.rllib.utils.typing import PolicyID
if TYPE_CHECKING:
from ray.rllib.evaluation import RolloutWorker
@@ -37,21 +42,32 @@ def __init__(self, observation_space, action_space, config, **kwargs):
self.total_debit = 0
self.n_steps_to_punish = 0
self.observed_n_step_in_current_epi = 0
- self._first_fake_step_played = False
+ self.last_own_algo_idx_in_eval = OWN_COOP_POLICY_IDX
self.opp_previous_obs = None
self.opp_new_obs = None
self.own_previous_obs = None
self.own_new_obs = None
self.both_previous_raw_obs = None
self.both_new_raw_obs = None
+ self.coop_own_rnn_state_before_last_act = None
+ self.coop_opp_rnn_state_before_last_act = None
# notation T in the paper
self.debit_threshold = config["debit_threshold"]
# notation alpha in the paper
self.punishment_multiplier = config["punishment_multiplier"]
self.working_state = config["working_state"]
+ assert (
+ self.working_state in WORKING_STATES
+ ), f"self.working_state {self.working_state}"
self.verbose = config["verbose"]
self.welfare_key = config["welfare_key"]
self.auto_load_checkpoint = config.get("auto_load_checkpoint", True)
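+        # If True, the "selfish" nested policy is trained on the opponent's
+        # negative reward (it learns to punish) instead of on the agent's
+        # own selfish reward.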
+ self.punish_instead_of_selfish = config.get(
+ "punish_instead_of_selfish", False
+ )
+ self.punish_instead_of_selfish_key = (
+ postprocessing.OPPONENT_NEGATIVE_REWARD
+ )
if self.working_state in WORKING_STATES_IN_EVALUATION:
self._set_models_for_evaluation()
@@ -60,45 +76,32 @@ def __init__(self, observation_space, action_space, config, **kwargs):
self.auto_load_checkpoint
and restore.LOAD_FROM_CONFIG_KEY in config.keys()
):
- restore.after_init_load_policy_checkpoint(self)
+ restore.before_loss_init_load_policy_checkpoint(self)
def _set_models_for_evaluation(self):
for algo in self.algorithms:
algo.model.eval()
+ @with_lock
@override(hierarchical.HierarchicalTorchPolicy)
- def compute_actions(
- self,
- obs_batch: Union[List[TensorType], TensorType],
- state_batches: Optional[List[TensorType]] = None,
- prev_action_batch: Union[List[TensorType], TensorType] = None,
- prev_reward_batch: Union[List[TensorType], TensorType] = None,
- info_batch: Optional[Dict[str, list]] = None,
- episodes: Optional[List["MultiAgentEpisode"]] = None,
- explore: Optional[bool] = None,
- timestep: Optional[int] = None,
- **kwargs,
- ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
-
- self._select_witch_algo_to_use()
- actions, state_out, extra_fetches = self.algorithms[
- self.active_algo_idx
- ].compute_actions(
- obs_batch,
- state_batches,
- prev_action_batch,
- prev_reward_batch,
- info_batch,
- episodes,
- explore,
- timestep,
- **kwargs,
+ def _compute_action_helper(
+ self, input_dict, state_batches, seq_lens, explore, timestep
+ ):
+ state_batches = self._select_witch_algo_to_use(state_batches)
+
+ self._track_last_coop_rnn_state(state_batches)
+ actions, state_out, extra_fetches = super()._compute_action_helper(
+ input_dict, state_batches, seq_lens, explore, timestep
)
- if self.verbose > 2:
- print(f"self.active_algo_idx {self.active_algo_idx}")
+ if self.verbose > 1:
+ print("algo idx", self.active_algo_idx, "action", actions)
+ print("extra_fetches", extra_fetches)
+ print("state_batches (in)", state_batches)
+ print("state_out", state_out)
+
return actions, state_out, extra_fetches
- def _select_witch_algo_to_use(self):
+ def _select_witch_algo_to_use(self, state_batches):
if (
self.working_state == WORKING_STATES[0]
or self.working_state == WORKING_STATES[4]
@@ -109,14 +112,17 @@ def _select_witch_algo_to_use(self):
or self.working_state == WORKING_STATES[3]
):
self.active_algo_idx = OWN_SELFISH_POLICY_IDX
+
elif self.working_state == WORKING_STATES[2]:
- self._select_algo_to_use_in_eval()
+ state_batches = self._select_algo_to_use_in_eval(state_batches)
else:
raise ValueError(
f'config["working_state"] ' f"must be one of {WORKING_STATES}"
)
- def _select_algo_to_use_in_eval(self):
+ return state_batches
+
+ def _select_algo_to_use_in_eval(self, state_batches):
if self.n_steps_to_punish == 0:
self.active_algo_idx = OWN_COOP_POLICY_IDX
elif self.n_steps_to_punish > 0:
@@ -125,23 +131,58 @@ def _select_algo_to_use_in_eval(self):
else:
raise ValueError("self.n_steps_to_punish can't be below zero")
+ state_batches = self._check_for_rnn_state_reset(
+ state_batches, "last_own_algo_idx_in_eval"
+ )
+
+ return state_batches
+
+ def _check_for_rnn_state_reset(self, state_batches, last_algo_idx: str):
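+        # When the active nested algorithm changes (e.g. cooperative ->
+        # punishing), reset the RNN state to the new algorithm's initial
+        # state so hidden states do not leak between the nested policies.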
+ if getattr(self, last_algo_idx) != self.active_algo_idx:
+ state_batches = self._get_initial_rnn_state(state_batches)
+ self._to_log["reset_rnn_state"] = self.active_algo_idx
+ setattr(self, last_algo_idx, self.active_algo_idx)
+ if self.verbose > 0:
+ print("reset_rnn_state")
+ else:
+ if "reset_rnn_state" in self._to_log.keys():
+ self._to_log.pop("reset_rnn_state")
+ return state_batches
+
+ def _get_initial_rnn_state(self, state_batches):
+ if "model" in self.config.keys() and self.config["model"]["use_lstm"]:
+ initial_state = self.algorithms[
+ self.active_algo_idx
+ ].get_initial_state()
+ initial_state = [
+ convert_to_torch_tensor(s, self.device) for s in initial_state
+ ]
+ initial_state = [s.unsqueeze(0) for s in initial_state]
+ msg = (
+ f"self.active_algo_idx {self.active_algo_idx} "
+ f"state_batches {state_batches} reset to initial rnn state"
+ )
+ # print(msg)
+ logger.info(msg)
+ return initial_state
+ else:
+ return state_batches
+
+ def _track_last_coop_rnn_state(self, state_batches):
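+        # Remember the cooperative policy's RNN state from before the last
+        # real action; it is reused to start the internal rollouts.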
+ if self.active_algo_idx == OWN_COOP_POLICY_IDX:
+ self.coop_own_rnn_state_before_last_act = state_batches
+
@override(hierarchical.HierarchicalTorchPolicy)
def _learn_on_batch(self, samples: SampleBatch):
- # working_state_idx = WORKING_STATES.index(self.working_state)
- # assert working_state_idx == OWN_COOP_POLICY_IDX \
- # or working_state_idx == OWN_SELFISH_POLICY_IDX, \
- # f"current working_state is {self.working_state} " \
- # f"but must be one of " \
- # f"[{WORKING_STATES[OWN_COOP_POLICY_IDX]}, " \
- # f"{WORKING_STATES[OWN_SELFISH_POLICY_IDX]}]"
-
if self.working_state == WORKING_STATES[0]:
algo_idx_to_train = OWN_COOP_POLICY_IDX
elif self.working_state == WORKING_STATES[1]:
algo_idx_to_train = OWN_SELFISH_POLICY_IDX
else:
- raise ValueError()
+ raise ValueError(
+ f"self.working_state must be one of " f"{WORKING_STATES[0:2]}"
+ )
samples = self._modify_batch_for_policy(algo_idx_to_train, samples)
algo_to_train = self.algorithms[algo_idx_to_train]
@@ -158,14 +199,25 @@ def _learn_on_batch(self, samples: SampleBatch):
def _modify_batch_for_policy(self, algo_idx_to_train, samples):
if algo_idx_to_train == OWN_COOP_POLICY_IDX:
samples = samples.copy()
- samples = self._overwrite_reward_for_policy_in_use(samples)
+ samples = self._overwrite_reward_for_policy_in_use(
+ samples, self.welfare_key
+ )
+ elif (
+ self.punish_instead_of_selfish
+ and algo_idx_to_train == OWN_SELFISH_POLICY_IDX
+ ):
+ samples = samples.copy()
+ samples = self._overwrite_reward_for_policy_in_use(
+ samples, self.punish_instead_of_selfish_key
+ )
+
return samples
- def _overwrite_reward_for_policy_in_use(self, samples_copy):
+ def _overwrite_reward_for_policy_in_use(self, samples_copy, welfare_key):
samples_copy[samples_copy.REWARDS] = np.array(
- samples_copy.data[self.welfare_key]
+ samples_copy.data[welfare_key]
)
- logger.debug(f"overwrite reward with {self.welfare_key}")
+ logger.debug(f"overwrite reward with {welfare_key}")
return samples_copy
def on_observation_fn(self, own_new_obs, opp_new_obs, both_new_raw_obs):
@@ -173,7 +225,7 @@ def on_observation_fn(self, own_new_obs, opp_new_obs, both_new_raw_obs):
# observation produced by this action. But we need the
# observation that caused the agent to play this action,
# thus the observation at step n-1
- if self._first_fake_step_played:
+ if self.own_new_obs is not None:
self.own_previous_obs = self.own_new_obs
self.opp_previous_obs = self.opp_new_obs
self.both_previous_raw_obs = self.both_new_raw_obs
@@ -193,20 +245,17 @@ def on_episode_step(
*args,
**kwargs,
):
- if self._first_fake_step_played:
- opp_obs, raw_obs, opp_a = self._get_information_from_opponent(
- policy_id, policy_ids, episode
- )
+ opp_obs, raw_obs, opp_a = self._get_information_from_opponent(
+ policy_id, policy_ids, episode
+ )
- # Ignored the first step in epi because the
- # actions provided are fake (they were not played)
- self._on_episode_step(
- opp_obs, raw_obs, opp_a, worker, base_env, episode, env_index
- )
+        # Ignore the first step in the episode because the
+        # actions provided are fake (they were not played)
+ self._on_episode_step(
+ opp_obs, raw_obs, opp_a, worker, base_env, episode, env_index
+ )
- self.observed_n_step_in_current_epi += 1
- else:
- self._first_fake_step_played = True
+ self.observed_n_step_in_current_epi += 1
def _get_information_from_opponent(self, agent_id, agent_ids, episode):
opp_agent_id = [one_id for one_id in agent_ids if one_id != agent_id][
@@ -243,11 +292,27 @@ def _on_episode_step(
coop_opp_simulated_action = (
self._simulate_action_from_cooperative_opponent(opp_obs)
)
+ assert (
+ len(worker.async_env.env_states) == 1
+ ), "amTFT in eval only works with one env not vector of envs"
+ assert (
+ worker.env.step_count_in_current_episode
+ == worker.async_env.env_states[
+ 0
+ ].env.step_count_in_current_episode
+ )
+ assert (
+ worker.env.step_count_in_current_episode
+ == self._base_env_at_last_step.get_unwrapped()[
+ 0
+ ].step_count_in_current_episode
+ + 1
+ )
self._update_total_debit(
last_obs,
opp_action,
worker,
- base_env,
+ self._base_env_at_last_step,
episode,
env_index,
coop_opp_simulated_action,
@@ -257,20 +322,53 @@ def _on_episode_step(
opp_action, coop_opp_simulated_action, worker, last_obs
)
+ self._base_env_at_last_step = copy.deepcopy(base_env)
+
self._to_log["n_steps_to_punish"] = self.n_steps_to_punish
self._to_log["debit_threshold"] = self.debit_threshold
def _is_punishment_planned(self):
return self.n_steps_to_punish > 0
+ def on_episode_start(self, *args, **kwargs):
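+        # In evaluation, snapshot the env, restart from the cooperative
+        # policy and re-initialize the simulated coop opponent's RNN state.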
+ if self.working_state in WORKING_STATES_IN_EVALUATION:
+ self._base_env_at_last_step = copy.deepcopy(kwargs["base_env"])
+ self.last_own_algo_idx_in_eval = OWN_COOP_POLICY_IDX
+ self.coop_opp_rnn_state_after_last_act = self.algorithms[
+ OPP_COOP_POLICY_IDX
+ ].get_initial_state()
+
def _simulate_action_from_cooperative_opponent(self, opp_obs):
+ if self.verbose > 1:
+ print("opp_obs for opp coop simu nonzero obs", np.nonzero(opp_obs))
+ for i, algo in enumerate(self.algorithms):
+ print("algo", i, algo)
+ self.coop_opp_rnn_state_before_last_act = (
+ self.coop_opp_rnn_state_after_last_act
+ )
(
coop_opp_simulated_action,
- _,
- self.coop_opp_extra_fetches,
- ) = self.algorithms[OPP_COOP_POLICY_IDX].compute_actions([opp_obs])
- # Returned a list
- coop_opp_simulated_action = coop_opp_simulated_action[0]
+ self.coop_opp_rnn_state_after_last_act,
+ coop_opp_extra_fetches,
+ ) = self.algorithms[OPP_COOP_POLICY_IDX].compute_single_action(
+ obs=opp_obs,
+ state=self.coop_opp_rnn_state_after_last_act,
+ )
+ if self.verbose > 1:
+ print(
+ coop_opp_simulated_action,
+ "coop_opp_extra_fetches",
+ coop_opp_extra_fetches,
+ )
+ print(
+ "state before simu coop opp",
+ self.coop_opp_rnn_state_before_last_act,
+ )
+ print(
+ "state after simu coop opp",
+ self.coop_opp_rnn_state_after_last_act,
+ )
+
return coop_opp_simulated_action
def _update_total_debit(
@@ -283,8 +381,19 @@ def _update_total_debit(
env_index,
coop_opp_simulated_action,
):
-
+ if self.verbose > 1:
+ print(
+ self.own_policy_id,
+ self.config[restore.LOAD_FROM_CONFIG_KEY][0].split("/")[-5],
+ )
if coop_opp_simulated_action != opp_action:
+ if self.verbose > 0:
+ print(
+ self.own_policy_id,
+ "coop_opp_simulated_action != opp_action:",
+ coop_opp_simulated_action,
+ opp_action,
+ )
if (
worker.env.step_count_in_current_episode
>= worker.env.max_steps
@@ -319,11 +428,6 @@ def _update_total_debit(
)
if coop_opp_simulated_action != opp_action:
if self.verbose > 0:
- print(
- "coop_opp_simulated_action != opp_action:",
- coop_opp_simulated_action,
- opp_action,
- )
print(f"debit {debit}")
print(
f"self.total_debit {self.total_debit}, previous was {tmp}"
@@ -331,8 +435,8 @@ def _update_total_debit(
def _is_starting_new_punishment_required(self, manual_threshold=None):
if manual_threshold is not None:
- return self.total_debit > manual_threshold
- return self.total_debit > self.debit_threshold
+ return self.total_debit >= manual_threshold
+ return self.total_debit >= self.debit_threshold
def _plan_punishment(
self, opp_action, coop_opp_simulated_action, worker, last_obs
@@ -355,7 +459,11 @@ def _plan_punishment(
print(f"reset self.total_debit to 0 since planned punishement")
def on_episode_end(self, *args, **kwargs):
- if self.working_state not in WORKING_STATES_IN_EVALUATION:
+ self._defensive_check_observed_n_opp_moves(*args, **kwargs)
+ self._if_in_eval_reset_debit_and_punish()
+
+ def _defensive_check_observed_n_opp_moves(self, *args, **kwargs):
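+        # In evaluation, every opponent move of the episode must have been
+        # observed, otherwise the amTFT bookkeeping is out of sync.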
+ if self.working_state in WORKING_STATES_IN_EVALUATION:
assert (
self.observed_n_step_in_current_epi
== kwargs["base_env"].get_unwrapped()[0].max_steps
@@ -365,13 +473,16 @@ def on_episode_end(self, *args, **kwargs):
f"{kwargs['base_env'].get_unwrapped()[0].max_steps} "
"steps per episodes."
)
- self.observed_n_step_in_current_epi = 0
+ self.observed_n_step_in_current_epi = 0
- if self.working_state == WORKING_STATES[2]:
+ def _if_in_eval_reset_debit_and_punish(self):
+ if self.working_state in WORKING_STATES_IN_EVALUATION:
self.total_debit = 0
self.n_steps_to_punish = 0
if self.verbose > 0:
- print(f"reset self.total_debit to 0 since end of episode")
+ logger.info(
+ "reset self.total_debit to 0 since end of " "episode"
+ )
def _compute_debit(
self,
@@ -390,6 +501,27 @@ def _compute_punishment_duration(
):
raise NotImplementedError()
+ def _get_last_rnn_states_before_rollouts(self):
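+        # Return the RNN states captured before the last real action (own
+        # coop policy and simulated coop opponent), keyed by policy id, so
+        # that internal rollouts resume from the current hidden states.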
+ if self.config["model"]["use_lstm"]:
+ return {
+ self.own_policy_id: self._squeezes_rnn_state(
+ self.coop_own_rnn_state_before_last_act
+ ),
+ self.opp_policy_id: self.coop_opp_rnn_state_before_last_act,
+ }
+
+ else:
+ return None
+
+ @staticmethod
+ def _squeezes_rnn_state(state):
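+        # Drop the leading batch dimension that was added for batched
+        # action computation.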
+ return [
+ s.squeeze(0)
+ if torch and isinstance(s, torch.Tensor)
+ else np.squeeze(s, 0)
+ for s in state
+ ]
+
class AmTFTCallbacks(callbacks.PolicyCallbacks):
def on_train_result(self, trainer, *args, **kwargs):
diff --git a/marltoolbox/algos/amTFT/inversed_policy.py b/marltoolbox/algos/amTFT/inversed_policy.py
index fef207e..578dcd8 100644
--- a/marltoolbox/algos/amTFT/inversed_policy.py
+++ b/marltoolbox/algos/amTFT/inversed_policy.py
@@ -4,14 +4,15 @@
class InversedAmTFTRolloutsTorchPolicy(
- policy_using_rollouts.AmTFTRolloutsTorchPolicy):
+ policy_using_rollouts.AmTFTRolloutsTorchPolicy
+):
"""
Instead of simulating the opponent, simulate our own policy and act as
if it was the opponent.
"""
- def _init_for_rollout(self, config):
- super()._init_for_rollout(config)
+ def _init(self, config):
+ super()._init(config)
self.ag_id_rollout_reward_to_read = self.own_policy_id
@override(base_policy.AmTFTPolicyBase)
@@ -26,86 +27,3 @@ def _switch_own_and_opp(self, agent_id):
output = super()._switch_own_and_opp(agent_id)
self.use_opponent_policies = not self.use_opponent_policies
return output
-
- # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy)
- # def _select_algo_to_use_in_eval(self):
- # assert self.performing_rollouts
- #
- # if not self.use_opponent_policies:
- # if self.n_steps_to_punish == 0:
- # self.active_algo_idx = base.OPP_COOP_POLICY_IDX
- # elif self.n_steps_to_punish > 0:
- # self.active_algo_idx = base.OPP_SELFISH_POLICY_IDX
- # self.n_steps_to_punish -= 1
- # else:
- # raise ValueError("self.n_steps_to_punish can't be below zero")
- # else:
- # # assert self.performing_rollouts
- # if self.n_steps_to_punish_opponent == 0:
- # self.active_algo_idx = base.OWN_COOP_POLICY_IDX
- # elif self.n_steps_to_punish_opponent > 0:
- # self.active_algo_idx = base.OWN_SELFISH_POLICY_IDX
- # self.n_steps_to_punish_opponent -= 1
- # else:
- # raise ValueError("self.n_steps_to_punish_opp "
- # "can't be below zero")
-
- # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy)
- # def _init_for_rollout(self, config):
- # super()._init_for_rollout(config)
- # # the policies stored as opponent_policies are our own policy
- # # (not the opponent's policies)
- # self.use_opponent_policies = False
-
- # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy)
- # def _prepare_to_perform_virtual_rollouts_in_env(self, worker):
- # outputs = super()._prepare_to_perform_virtual_rollouts_in_env(
- # worker)
- # # the policies stored as opponent_policies are our own policy
- # # (not the opponent's policies)
- # self.use_opponent_policies = True
- # return outputs
-
- # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy)
- # def _stop_performing_virtual_rollouts_in_env(self, n_steps_to_punish):
- # super()._stop_performing_virtual_rollouts_in_env(n_steps_to_punish)
- # # the policies stored as opponent_policies are our own policy
- # # (not the opponent's policies)
- # self.use_opponent_policies = True
-
- # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy)
- # def compute_actions(
- # self,
- # obs_batch: Union[List[TensorType], TensorType],
- # state_batches: Optional[List[TensorType]] = None,
- # prev_action_batch: Union[List[TensorType], TensorType] = None,
- # prev_reward_batch: Union[List[TensorType], TensorType] = None,
- # info_batch: Optional[Dict[str, list]] = None,
- # episodes: Optional[List["MultiAgentEpisode"]] = None,
- # explore: Optional[bool] = None,
- # timestep: Optional[int] = None,
- # **kwargs) -> \
- # Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
- #
- # # Option to overwrite action during internal rollouts
- # if self.use_opponent_policies:
- # if len(self.overwrite_action) > 0:
- # actions, state_out, extra_fetches = \
- # self.overwrite_action.pop(0)
- # if self.verbose > 1:
- # print("overwrite actions", actions, type(actions))
- # return actions, state_out, extra_fetches
- #
- # return super().compute_actions(
- # obs_batch, state_batches, prev_action_batch, prev_reward_batch,
- # info_batch, episodes, explore, timestep, **kwargs)
-
-# debit = self._compute_debit(
-# last_obs, opp_action, worker, base_env,
-# episode, env_index, coop_opp_simulated_action)
-
-# self.n_steps_to_punish = self._compute_punishment_duration(
-# opp_action,
-# coop_opp_simulated_action,
-# worker,
-# last_obs)
diff --git a/marltoolbox/algos/amTFT/level_1_exploiter.py b/marltoolbox/algos/amTFT/level_1_exploiter.py
index c395ded..7f5c52d 100644
--- a/marltoolbox/algos/amTFT/level_1_exploiter.py
+++ b/marltoolbox/algos/amTFT/level_1_exploiter.py
@@ -3,6 +3,7 @@
from ray.rllib import SampleBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
+from ray.rllib.utils.threading import with_lock
from marltoolbox.algos import hierarchical, amTFT
from marltoolbox.algos.amTFT import base
@@ -32,7 +33,7 @@ def __init__(self, observation_space, action_space, config, **kwargs):
self.exploiter_activated = config.get("exploiter_activated")
if restore.LOAD_FROM_CONFIG_KEY in config.keys():
- restore.after_init_load_policy_checkpoint(self)
+ restore.before_loss_init_load_policy_checkpoint(self)
if self.working_state in base.WORKING_STATES_IN_EVALUATION:
if self.exploiter_activated:
@@ -48,12 +49,16 @@ def working_state(self, value):
def on_episode_start(self, worker, *args, **kwargs):
self.worker = worker
+ self.simulated_opponent.on_episode_start(*args, **kwargs)
def on_episode_end(self, *args, **kwargs):
self.simulated_opponent.on_episode_end(*args, **kwargs)
- def on_observation_fn(self, *args, **kwargs):
- self.simulated_opponent.on_observation_fn(*args, **kwargs)
+ def on_observation_fn(self, own_new_obs, opp_new_obs, both_new_raw_obs):
+ self.simulated_opponent.on_observation_fn(
+ own_new_obs, opp_new_obs, both_new_raw_obs
+ )
+ self.both_new_raw_obs = both_new_raw_obs
def on_episode_step(self, *args, **kwargs):
self.simulated_opponent.on_episode_step(*args, **kwargs)
@@ -78,12 +83,14 @@ def _train_dual_policies(self, samples: SampleBatch):
)
return learner_stats
+ @with_lock
@override(level_1.Level1ExploiterTorchPolicy)
- def compute_actions(self, *args, **kwargs):
-
+ def _compute_action_helper(self, *args, **kwargs):
# use the simulated amTFT opponent when performing rollouts
if self.simulated_opponent.performing_rollouts:
- return self.simulated_opponent.compute_actions(*args, **kwargs)
+ return self.simulated_opponent._compute_action_helper(
+ *args, **kwargs
+ )
self._to_log["use_simulated_opponent"] = 0.0
self._to_log["use_exploiter_cooperating"] = 0.0
@@ -93,16 +100,19 @@ def compute_actions(self, *args, **kwargs):
if self.working_state not in base.WORKING_STATES_IN_EVALUATION:
self._to_log["use_simulated_opponent"] = 1.0
self.active_algo_idx = self.LEVEL_1_POLICY_IDX
- return self.simulated_opponent.compute_actions(*args, **kwargs)
+ return self.simulated_opponent._compute_action_helper(
+ *args, **kwargs
+ )
# Use the simulated amTFT opponent when not using the exploiter
if not self.exploiter_activated:
self._to_log["use_simulated_opponent"] = 1.0
self.active_algo_idx = self.LEVEL_1_POLICY_IDX
- return self.simulated_opponent.compute_actions(*args, **kwargs)
+ return self.simulated_opponent._compute_action_helper(
+ *args, **kwargs
+ )
- outputs = super().compute_actions(*args, **kwargs)
- return outputs
+ return super()._compute_action_helper(*args, **kwargs)
def _is_cooperation_needed_to_prevent_punishement(
self, selfish_action_tuple, *args, **kwargs
@@ -117,9 +127,14 @@ def _lookahead_for_opponent_punishing(
self, selfish_action, *args, **kwargs
) -> bool:
selfish_action = selfish_action[0]
- last_obs = kwargs["episodes"][0]._agent_to_last_raw_obs
selfish_step_debit = self.simulated_opponent._compute_debit(
- last_obs, selfish_action, self.worker, None, None, None, None
+ self.both_new_raw_obs,
+ selfish_action,
+ self.worker,
+ None,
+ None,
+ None,
+ None,
)
if (
"expl_min_anticipated_gain" not in self._to_log.keys()
diff --git a/marltoolbox/algos/amTFT/policy_using_rollouts.py b/marltoolbox/algos/amTFT/policy_using_rollouts.py
index d6ba59b..0cc5cb6 100644
--- a/marltoolbox/algos/amTFT/policy_using_rollouts.py
+++ b/marltoolbox/algos/amTFT/policy_using_rollouts.py
@@ -1,15 +1,13 @@
-from typing import List, Union, Optional, Dict, Tuple
-
import numpy as np
-from ray.rllib.evaluation import MultiAgentEpisode
-from ray.rllib.utils.typing import TensorType
+from ray.rllib.utils import override
+from ray.rllib.utils.threading import with_lock
from marltoolbox.algos.amTFT.base import (
OWN_COOP_POLICY_IDX,
OWN_SELFISH_POLICY_IDX,
OPP_SELFISH_POLICY_IDX,
OPP_COOP_POLICY_IDX,
- WORKING_STATES,
+ WORKING_STATES_IN_EVALUATION,
)
from marltoolbox.algos.amTFT.base_policy import AmTFTPolicyBase
from marltoolbox.utils import rollout
@@ -18,9 +16,9 @@
class AmTFTRolloutsTorchPolicy(AmTFTPolicyBase):
def __init__(self, observation_space, action_space, config, **kwargs):
super().__init__(observation_space, action_space, config, **kwargs)
- self._init_for_rollout(self.config)
+ self._init(self.config)
- def _init_for_rollout(self, config):
+ def _init(self, config):
self.last_k = config["last_k"]
self.use_opponent_policies = False
self.rollout_length = config["rollout_length"]
@@ -31,49 +29,36 @@ def _init_for_rollout(self, config):
self.opp_policy_id = config["opp_policy_id"]
self.n_steps_to_punish_opponent = 0
self.ag_id_rollout_reward_to_read = self.opp_policy_id
+ self.last_opp_algo_idx_in_rollout = OPP_COOP_POLICY_IDX
+ self.last_own_algo_idx_in_rollout = OWN_COOP_POLICY_IDX
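+        # If True, approximate the debit with rollouts of length 1 instead
+        # of the full rollout_length.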
+ self.use_short_debit_rollout = config.get(
+ "use_short_debit_rollout", False
+ )
- # Don't support LSTM (at least because of action
- # overwriting needed in the rollouts)
- if "model" in config.keys():
- if "use_lstm" in config["model"].keys():
- assert not config["model"]["use_lstm"]
-
- def compute_actions(
- self,
- obs_batch: Union[List[TensorType], TensorType],
- state_batches: Optional[List[TensorType]] = None,
- prev_action_batch: Union[List[TensorType], TensorType] = None,
- prev_reward_batch: Union[List[TensorType], TensorType] = None,
- info_batch: Optional[Dict[str, list]] = None,
- episodes: Optional[List["MultiAgentEpisode"]] = None,
- explore: Optional[bool] = None,
- timestep: Optional[int] = None,
- **kwargs,
- ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
-
+ @with_lock
+ @override(AmTFTPolicyBase)
+ def _compute_action_helper(
+ self, input_dict, state_batches, seq_lens, explore, timestep
+ ):
# Option to overwrite action during internal rollouts
if self.use_opponent_policies:
if len(self.overwrite_action) > 0:
actions, state_out, extra_fetches = self.overwrite_action.pop(
0
)
+ self.last_opp_algo_idx_in_rollout = OPP_SELFISH_POLICY_IDX
if self.verbose > 1:
- print("overwritten actions", actions, type(actions))
+ print(
+ "overwritten actions", actions, "state_out", state_out
+ )
return actions, state_out, extra_fetches
- return super().compute_actions(
- obs_batch,
- state_batches,
- prev_action_batch,
- prev_reward_batch,
- info_batch,
- episodes,
- explore,
- timestep,
- **kwargs,
+ return super()._compute_action_helper(
+ input_dict, state_batches, seq_lens, explore, timestep
)
- def _select_algo_to_use_in_eval(self):
+ @override(AmTFTPolicyBase)
+ def _select_algo_to_use_in_eval(self, state_batches):
if not self.use_opponent_policies:
if self.n_steps_to_punish == 0:
self.active_algo_idx = OWN_COOP_POLICY_IDX
@@ -82,6 +67,16 @@ def _select_algo_to_use_in_eval(self):
self.n_steps_to_punish -= 1
else:
raise ValueError("self.n_steps_to_punish can't be below zero")
+
+ if self.performing_rollouts:
+ state_batches = self._check_for_rnn_state_reset(
+ state_batches, "last_own_algo_idx_in_rollout"
+ )
+ else:
+ state_batches = self._check_for_rnn_state_reset(
+ state_batches, "last_own_algo_idx_in_eval"
+ )
+
else:
assert self.performing_rollouts
if self.n_steps_to_punish_opponent == 0:
@@ -94,6 +89,17 @@ def _select_algo_to_use_in_eval(self):
"self.n_steps_to_punish_opp " "can't be below zero"
)
+ state_batches = self._check_for_rnn_state_reset(
+ state_batches, "last_opp_algo_idx_in_rollout"
+ )
+
+ return state_batches
+
+ @override(AmTFTPolicyBase)
+ def _track_last_coop_rnn_state(self, state_batches):
+ if not self.performing_rollouts and not self.use_opponent_policies:
+ super()._track_last_coop_rnn_state(state_batches)
+
def _on_episode_step(
self,
opp_obs,
@@ -104,16 +110,16 @@ def _on_episode_step(
episode,
env_index,
):
- if not self.performing_rollouts:
- super()._on_episode_step(
- opp_obs,
- last_obs,
- opp_action,
- worker,
- base_env,
- episode,
- env_index,
- )
+ assert not self.performing_rollouts
+ super()._on_episode_step(
+ opp_obs,
+ last_obs,
+ opp_action,
+ worker,
+ base_env,
+ episode,
+ env_index,
+ )
def _compute_debit(
self,
@@ -126,11 +132,13 @@ def _compute_debit(
coop_opp_simulated_action,
):
approximated_debit = self._compute_debit_using_rollouts(
- last_obs, opp_action, worker
+ last_obs, opp_action, worker, base_env
)
return approximated_debit
- def _compute_debit_using_rollouts(self, last_obs, opp_action, worker):
+ def _compute_debit_using_rollouts(
+ self, last_obs, opp_action, worker, base_env
+ ):
(
n_steps_to_punish,
policy_map,
@@ -138,6 +146,9 @@ def _compute_debit_using_rollouts(self, last_obs, opp_action, worker):
) = self._prepare_to_perform_virtual_rollouts_in_env(worker)
# Cooperative rollouts
+ if self.verbose > 1:
+ print("Compute debit")
+ print("Cooperative rollouts")
(
mean_total_reward_for_totally_coop_opp,
_,
@@ -148,8 +159,12 @@ def _compute_debit_using_rollouts(self, last_obs, opp_action, worker):
partially_coop=False,
opp_action=None,
last_obs=last_obs,
+ rollout_length=1 if self.use_short_debit_rollout else None,
+ base_env=base_env,
)
# Cooperative rollouts with first action as the real one
+ if self.verbose > 1:
+ print("Parital cooperative rollouts")
(
mean_total_reward_for_partially_coop_opp,
_,
@@ -160,6 +175,8 @@ def _compute_debit_using_rollouts(self, last_obs, opp_action, worker):
partially_coop=True,
opp_action=opp_action,
last_obs=last_obs,
+ rollout_length=1 if self.use_short_debit_rollout else None,
+ base_env=base_env,
)
if self.verbose > 0:
@@ -489,24 +506,30 @@ def _compute_opp_mean_total_reward(
opp_action,
last_obs,
k_to_explore=0,
+ rollout_length=None,
+ base_env=None,
):
opp_total_rewards = []
for i in range(self.n_rollout_replicas // 2):
+ self.last_opp_algo_idx_in_rollout = OPP_COOP_POLICY_IDX
+ self.last_own_algo_idx_in_rollout = OWN_COOP_POLICY_IDX
self.n_steps_to_punish = k_to_explore
self.n_steps_to_punish_opponent = k_to_explore
if partially_coop:
- assert len(self.overwrite_action) == 0
- self.overwrite_action = [
- (np.array([opp_action]), [], {}),
- ]
+ self._set_overwrite_action(opp_action)
+ last_rnn_states = self._get_last_rnn_states_before_rollouts()
coop_rollout = rollout.internal_rollout(
worker,
- num_steps=self.rollout_length,
+ num_steps=self.rollout_length
+ if rollout_length is None
+ else rollout_length,
policy_map=policy_map,
last_obs=last_obs,
policy_agent_mapping=policy_agent_mapping,
reset_env_before=False,
num_episodes=1,
+ last_rnn_states=last_rnn_states,
+ base_env=base_env,
)
assert (
coop_rollout._num_episodes == 1
@@ -536,8 +559,33 @@ def _compute_opp_mean_total_reward(
opp_mean_total_reward = sum(opp_total_rewards) / len(opp_total_rewards)
return opp_mean_total_reward, n_steps_played
+ def _set_overwrite_action(self, opp_action):
+ assert len(self.overwrite_action) == 0
+ if self.config["model"]["use_lstm"]:
+            # When we play the real opponent action, don't break the sequence
+            # of RNN states for the simulated cooperative opponent
+ # rnn_state = [[el] for el in self.coop_opp_rnn_state_after_last_act]
+ rnn_state = [
+ self._get_initial_rnn_state(None)
+ for el in self.coop_opp_rnn_state_after_last_act
+ ]
+
+ else:
+ rnn_state = []
+ self.overwrite_action = [
+ (np.array([opp_action]), rnn_state, {}),
+ ]
+
+ @override(AmTFTPolicyBase)
def on_episode_end(self, *args, **kwargs):
- if self.working_state == WORKING_STATES[2]:
- self.total_debit = 0
- self.n_steps_to_punish = 0
+ assert not self.performing_rollouts
+ if self.working_state in WORKING_STATES_IN_EVALUATION:
self.n_steps_to_punish_opponent = 0
+ super().on_episode_end(*args, **kwargs)
+
+ @override(AmTFTPolicyBase)
+ def on_episode_start(self, *args, **kwargs):
+ assert not self.performing_rollouts
+ if self.working_state in WORKING_STATES_IN_EVALUATION:
+ self.last_own_algo_idx_in_eval = OWN_COOP_POLICY_IDX
+ super().on_episode_start(*args, **kwargs)
diff --git a/marltoolbox/algos/amTFT/train_helper.py b/marltoolbox/algos/amTFT/train_helper.py
index 9d6109b..34dafd8 100644
--- a/marltoolbox/algos/amTFT/train_helper.py
+++ b/marltoolbox/algos/amTFT/train_helper.py
@@ -5,11 +5,12 @@
from ray import tune
from ray.rllib.agents.dqn import DQNTrainer
+from marltoolbox import utils
from marltoolbox.algos import amTFT
from marltoolbox.scripts.aggregate_and_plot_tensorboard_data import (
add_summary_plots,
)
-from marltoolbox.utils import miscellaneous, restore
+from marltoolbox.utils import miscellaneous, restore, tune_analysis
def train_amtft(
@@ -117,10 +118,10 @@ def _get_plot_keys(plot_keys, plot_assemblage_tags):
def _extract_selfish_policies_checkpoints(tune_analysis_selfish_policies):
- checkpoints = miscellaneous.extract_checkpoints(
+ checkpoints = restore.extract_checkpoints_from_tune_analysis(
tune_analysis_selfish_policies
)
- seeds = miscellaneous.extract_config_values_from_tune_analysis(
+ seeds = utils.tune_analysis.extract_config_values_from_tune_analysis(
tune_analysis_selfish_policies, "seed"
)
seed_to_checkpoint = {}
diff --git a/marltoolbox/algos/amTFT/weights_exchanger.py b/marltoolbox/algos/amTFT/weights_exchanger.py
index c296a05..380e812 100644
--- a/marltoolbox/algos/amTFT/weights_exchanger.py
+++ b/marltoolbox/algos/amTFT/weights_exchanger.py
@@ -44,6 +44,7 @@ def _share_weights_during_training(trainer):
local_policy_map
)
if in_training:
+ print("amTFT _share_weights_during_training")
WeightsExchanger._check_only_amTFT_policies(local_policy_map)
policies_weights = trainer.get_weights()
policies_weights = (
diff --git a/marltoolbox/algos/augmented_dqn.py b/marltoolbox/algos/augmented_dqn.py
index 0cbf986..e5bbe7c 100644
--- a/marltoolbox/algos/augmented_dqn.py
+++ b/marltoolbox/algos/augmented_dqn.py
@@ -1,7 +1,6 @@
-from typing import Dict
-
import torch
import torch.nn.functional as F
+from ray.rllib.agents.dqn import dqn_torch_policy
from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS
from ray.rllib.agents.dqn.dqn_torch_policy import ComputeTDErrorMixin
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
@@ -11,6 +10,7 @@
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.torch_ops import FLOAT_MIN
from ray.rllib.utils.typing import TensorType
+from typing import Dict
from marltoolbox.utils import log, optimizers, policy
@@ -20,36 +20,34 @@ def build_q_losses_wt_additional_logs(
) -> TensorType:
"""
Copy of build_q_losses with additional values saved into the policy
- Made only 2 changes, see in comments.
+    Made only one change; see the comment in the code.
"""
config = policy.config
# Q-network evaluation.
- q_t, q_logits_t, q_probs_t = compute_q_values(
+ q_t, q_logits_t, q_probs_t, _ = compute_q_values(
policy,
- policy.q_model,
- train_batch[SampleBatch.CUR_OBS],
+ model,
+ {"obs": train_batch[SampleBatch.CUR_OBS]},
explore=False,
is_training=True,
)
- # Addition 1 out of 2
- policy.last_q_t = q_t.clone()
-
# Target Q-network evaluation.
- q_tp1, q_logits_tp1, q_probs_tp1 = compute_q_values(
+ q_tp1, _, q_probs_tp1, _ = compute_q_values(
policy,
policy.target_q_model,
- train_batch[SampleBatch.NEXT_OBS],
+ {"obs": train_batch[SampleBatch.NEXT_OBS]},
explore=False,
is_training=True,
)
- # Addition 2 out of 2
- policy.last_target_q_t = q_tp1.clone()
+    # Only additions made to the original function
+ policy.last_q_t = q_t.clone().detach()
+ policy.last_target_q_t = q_tp1.clone().detach()
# Q scores for actions which we know were selected in the given state.
one_hot_selection = F.one_hot(
- train_batch[SampleBatch.ACTIONS], policy.action_space.n
+ train_batch[SampleBatch.ACTIONS].long(), policy.action_space.n
)
q_t_selected = torch.sum(
torch.where(
@@ -68,10 +66,11 @@ def build_q_losses_wt_additional_logs(
q_tp1_using_online_net,
q_logits_tp1_using_online_net,
q_dist_tp1_using_online_net,
+ _,
) = compute_q_values(
policy,
- policy.q_model,
- train_batch[SampleBatch.NEXT_OBS],
+ model,
+ {"obs": train_batch[SampleBatch.NEXT_OBS]},
explore=False,
is_training=True,
)
@@ -108,20 +107,12 @@ def build_q_losses_wt_additional_logs(
q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
)
- if PRIO_WEIGHTS not in train_batch.keys():
- assert config["prioritized_replay"] is False
- prio_weights = torch.tensor(
- [1.0] * len(train_batch[SampleBatch.REWARDS])
- ).to(policy.device)
- else:
- prio_weights = train_batch[PRIO_WEIGHTS]
-
policy.q_loss = QLoss(
q_t_selected,
q_logits_t_selected,
q_tp1_best,
q_probs_tp1_best,
- prio_weights,
+ train_batch[PRIO_WEIGHTS],
train_batch[SampleBatch.REWARDS],
train_batch[SampleBatch.DONES].float(),
config["gamma"],
@@ -134,28 +125,25 @@ def build_q_losses_wt_additional_logs(
return policy.q_loss.loss
-def build_q_stats_wt_addtional_log(
- policy: Policy, batch
-) -> Dict[str, TensorType]:
- entropy_avg, entropy_single = log._compute_entropy_from_raw_q_values(
- policy, policy.last_q_t.clone()
- )
+def my_build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
+ q_stats = dqn_torch_policy.build_q_stats(policy, batch)
- return dict(
+ entropy_avg, _ = log.compute_entropy_from_raw_q_values(
+ policy, policy.last_q_t
+ )
+ q_stats.update(
{
"entropy_avg": entropy_avg,
- "cur_lr": policy.cur_lr,
- },
- **policy.q_loss.stats,
+ }
)
+ return q_stats
+
MyDQNTorchPolicy = DQNTorchPolicy.with_updates(
optimizer_fn=optimizers.sgd_optimizer_dqn,
loss_fn=build_q_losses_wt_additional_logs,
- stats_fn=log.augment_stats_fn_wt_additionnal_logs(
- build_q_stats_wt_addtional_log
- ),
+ stats_fn=log.augment_stats_fn_wt_additionnal_logs(my_build_q_stats),
before_init=policy.my_setup_early_mixins,
mixins=[
TargetNetworkMixin,
diff --git a/marltoolbox/algos/augmented_r2d2.py b/marltoolbox/algos/augmented_r2d2.py
new file mode 100644
index 0000000..0a46d65
--- /dev/null
+++ b/marltoolbox/algos/augmented_r2d2.py
@@ -0,0 +1,224 @@
+"""PyTorch policy class used for R2D2."""
+
+import torch
+import torch.nn.functional as F
+from ray.rllib.agents.dqn import r2d2_torch_policy
+from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS
+from ray.rllib.agents.dqn.dqn_torch_policy import ComputeTDErrorMixin
+from ray.rllib.agents.dqn.dqn_torch_policy import compute_q_values
+from ray.rllib.agents.dqn.r2d2_torch_policy import (
+ R2D2TorchPolicy,
+ h_function,
+ h_inverse,
+)
+from ray.rllib.agents.dqn.simple_q_torch_policy import TargetNetworkMixin
+from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.utils.torch_ops import FLOAT_MIN
+from ray.rllib.utils.torch_ops import huber_loss, sequence_mask, l2_loss
+from ray.rllib.utils.typing import TensorType
+from typing import Dict
+
+from marltoolbox.utils import log, optimizers, policy
+
+
+def my_r2d2_loss(policy: Policy, model, _, train_batch: SampleBatch):
+ """Constructs the loss for R2D2TorchPolicy.
+
+ Args:
+ policy (Policy): The Policy to calculate the loss for.
+ model (ModelV2): The Model to calculate the loss for.
+ train_batch (SampleBatch): The training data.
+
+ Returns:
+        Tuple of (weights, td_error, q_selected, reduce_mean_valid), used by
+        the loss wrappers below to build the final loss tensor.
+ """
+ config = policy.config
+
+ # Construct internal state inputs.
+ i = 0
+ state_batches = []
+ while "state_in_{}".format(i) in train_batch:
+ state_batches.append(train_batch["state_in_{}".format(i)])
+ i += 1
+ assert state_batches
+
+ # Q-network evaluation (at t).
+ q, _, _, _ = compute_q_values(
+ policy,
+ model,
+ train_batch,
+ state_batches=state_batches,
+ seq_lens=train_batch.get("seq_lens"),
+ explore=False,
+ is_training=True,
+ )
+
+ # Target Q-network evaluation (at t+1).
+ q_target, _, _, _ = compute_q_values(
+ policy,
+ policy.target_q_model,
+ train_batch,
+ state_batches=state_batches,
+ seq_lens=train_batch.get("seq_lens"),
+ explore=False,
+ is_training=True,
+ )
+
+ # Only additions
+ policy.last_q = q.clone().detach()
+ policy.last_q_target = q_target.clone().detach()
+
+ actions = train_batch[SampleBatch.ACTIONS].long()
+ dones = train_batch[SampleBatch.DONES].float()
+ rewards = train_batch[SampleBatch.REWARDS]
+ weights = train_batch[PRIO_WEIGHTS]
+
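+    # B: number of sequences in the batch, T: (padded) time steps per
+    # sequence; q is flattened to shape [B * T, num_actions].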
+ B = state_batches[0].shape[0]
+ T = q.shape[0] // B
+
+ # Q scores for actions which we know were selected in the given state.
+ one_hot_selection = F.one_hot(actions, policy.action_space.n)
+ q_selected = torch.sum(
+ torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=policy.device))
+ * one_hot_selection,
+ 1,
+ )
+
+ if config["double_q"]:
+ best_actions = torch.argmax(q, dim=1)
+ else:
+ best_actions = torch.argmax(q_target, dim=1)
+
+ best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
+ q_target_best = torch.sum(
+ torch.where(
+ q_target > FLOAT_MIN,
+ q_target,
+ torch.tensor(0.0, device=policy.device),
+ )
+ * best_actions_one_hot,
+ dim=1,
+ )
+
+ if config["num_atoms"] > 1:
+ raise ValueError("Distributional R2D2 not supported yet!")
+ else:
+ q_target_best_masked_tp1 = (1.0 - dones) * torch.cat(
+ [q_target_best[1:], torch.tensor([0.0], device=policy.device)]
+ )
+
+ if config["use_h_function"]:
+ h_inv = h_inverse(
+ q_target_best_masked_tp1, config["h_function_epsilon"]
+ )
+ target = h_function(
+ rewards + config["gamma"] ** config["n_step"] * h_inv,
+ config["h_function_epsilon"],
+ )
+ else:
+ target = (
+ rewards
+ + config["gamma"] ** config["n_step"]
+ * q_target_best_masked_tp1
+ )
+
+ # Seq-mask all loss-related terms.
+ seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1]
+ # Mask away also the burn-in sequence at the beginning.
+ burn_in = policy.config["burn_in"]
+ if burn_in > 0 and burn_in < T:
+ seq_mask[:, :burn_in] = False
+
+ num_valid = torch.sum(seq_mask)
+
+ def reduce_mean_valid(t):
+ return torch.sum(t[seq_mask]) / num_valid
+
+    # Make sure to use the correct time indices:
+ # Q(t) - [gamma * r + Q^(t+1)]
+ q_selected = q_selected.reshape([B, T])[:, :-1]
+ td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
+ td_error = td_error * seq_mask
+ weights = weights.reshape([B, T])[:, :-1]
+
+ return weights, td_error, q_selected, reduce_mean_valid
+
+
+def my_r2d2_loss_wt_huber_loss(
+ policy: Policy, model, _, train_batch: SampleBatch
+) -> TensorType:
+
+ weights, td_error, q_selected, reduce_mean_valid = my_r2d2_loss(
+ policy, model, _, train_batch
+ )
+
+ policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
+ policy._td_error = td_error.reshape([-1])
+ policy._loss_stats = {
+ "mean_q": reduce_mean_valid(q_selected),
+ "min_q": torch.min(q_selected),
+ "max_q": torch.max(q_selected),
+ "mean_td_error": reduce_mean_valid(td_error),
+ }
+
+ return policy._total_loss
+
+
+def my_r2d2_loss_wtout_huber_loss(
+ policy: Policy, model, _, train_batch: SampleBatch
+) -> TensorType:
+
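+    # Same as my_r2d2_loss_wt_huber_loss, but penalizes the TD error with
+    # an L2 (MSE) loss instead of the Huber loss.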
+ weights, td_error, q_selected, reduce_mean_valid = my_r2d2_loss(
+ policy, model, _, train_batch
+ )
+
+ policy._total_loss = reduce_mean_valid(weights * l2_loss(td_error))
+ policy._td_error = td_error.reshape([-1])
+ policy._loss_stats = {
+ "mean_q": reduce_mean_valid(q_selected),
+ "min_q": torch.min(q_selected),
+ "max_q": torch.max(q_selected),
+ "mean_td_error": reduce_mean_valid(td_error),
+ }
+
+ return policy._total_loss
+
+
+def my_build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
+ q_stats = r2d2_torch_policy.build_q_stats(policy, batch)
+
+ entropy_avg, _ = log.compute_entropy_from_raw_q_values(
+ policy, policy.last_q.clone()
+ )
+ q_stats.update(
+ {
+ "entropy_avg": entropy_avg,
+ }
+ )
+
+ return q_stats
+
+
+MyR2D2TorchPolicy = R2D2TorchPolicy.with_updates(
+ optimizer_fn=optimizers.sgd_optimizer_dqn,
+ loss_fn=my_r2d2_loss_wt_huber_loss,
+ stats_fn=log.augment_stats_fn_wt_additionnal_logs(my_build_q_stats),
+ before_init=policy.my_setup_early_mixins,
+ mixins=[
+ TargetNetworkMixin,
+ ComputeTDErrorMixin,
+ policy.MyLearningRateSchedule,
+ ],
+)
+
+MyR2D2TorchPolicyWtMSELoss = MyR2D2TorchPolicy.with_updates(
+ loss_fn=my_r2d2_loss_wtout_huber_loss,
+)
+
+MyAdamDQNTorchPolicy = MyR2D2TorchPolicy.with_updates(
+ optimizer_fn=optimizers.adam_optimizer_dqn,
+)
diff --git a/marltoolbox/algos/exploiters/dual_behavior.py b/marltoolbox/algos/exploiters/dual_behavior.py
index 31c8cfd..13e1527 100644
--- a/marltoolbox/algos/exploiters/dual_behavior.py
+++ b/marltoolbox/algos/exploiters/dual_behavior.py
@@ -4,6 +4,7 @@
from ray.rllib import SampleBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
+from ray.rllib.utils.threading import with_lock
from marltoolbox.algos import hierarchical
from marltoolbox.utils import postprocessing
@@ -34,6 +35,7 @@ def __init__(self, observation_space, action_space, config, **kwargs):
super().__init__(observation_space, action_space, config, **kwargs)
self.active_algo_idx = self.SELFISH_POLICY_IDX
+ @with_lock
@override(hierarchical.HierarchicalTorchPolicy)
def _learn_on_batch(self, samples: SampleBatch):
return self._train_dual_policies(samples)
diff --git a/marltoolbox/algos/exploiters/level_1.py b/marltoolbox/algos/exploiters/level_1.py
index cb2807a..b5bd8d8 100644
--- a/marltoolbox/algos/exploiters/level_1.py
+++ b/marltoolbox/algos/exploiters/level_1.py
@@ -1,10 +1,11 @@
import logging
-from typing import TYPE_CHECKING
from ray.rllib import SampleBatch
-from ray.rllib.policy import Policy
+from ray.rllib.policy import TorchPolicy
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
+from ray.rllib.utils.threading import with_lock
+from typing import TYPE_CHECKING
from marltoolbox.algos.exploiters import dual_behavior
@@ -17,7 +18,7 @@
dual_behavior.DEFAULT_CONFIG,
{
"lookahead_n_times": int(1),
- }
+ },
)
@@ -34,6 +35,7 @@ class Level1ExploiterTorchPolicy(dual_behavior.DualBehaviorTorchPolicy):
"""
+
LEVEL_1_POLICY_IDX = 2
def __init__(self, observation_space, action_space, config, **kwargs):
@@ -50,13 +52,14 @@ def _learn_on_batch(self, samples: SampleBatch):
def _train_level_1_policy(self, learner_stats, samples):
raise NotImplementedError()
- @override(Policy)
- def compute_actions(self, *args, **kwargs):
-
+ @with_lock
+ @override(TorchPolicy)
+ def _compute_action_helper(self, *args, **kwargs):
selfish_action_tuple = self._compute_selfish_action(*args, **kwargs)
coop_needed = self._is_cooperation_needed_to_prevent_punishement(
- selfish_action_tuple, *args, **kwargs)
+ selfish_action_tuple, *args, **kwargs
+ )
self._to_log["coop_needed"] = coop_needed
@@ -64,7 +67,7 @@ def compute_actions(self, *args, **kwargs):
self._to_log["use_exploiter_cooperating"] = 1.0
self._to_log["use_exploiter_selfish"] = 0.0
self.active_algo_idx = self.COOP_POLICY_IDX
- return super().compute_actions(*args, **kwargs)
+ return super()._compute_action_helper(*args, **kwargs)
else:
self._to_log["use_exploiter_cooperating"] = 0.0
self._to_log["use_exploiter_selfish"] = 1.0
@@ -74,12 +77,13 @@ def compute_actions(self, *args, **kwargs):
def _compute_selfish_action(self, *args, **kwargs):
tmp_active_algo_idx = self.active_algo_idx
self.active_algo_idx = self.SELFISH_POLICY_IDX
- selfish_action_tuple = super().compute_actions(*args, **kwargs)
+ selfish_action_tuple = super()._compute_action_helper(*args, **kwargs)
self.active_algo_idx = tmp_active_algo_idx
return selfish_action_tuple
def _is_cooperation_needed_to_prevent_punishement(
- self, selfish_action_tuple, *args, **kwargs):
+ self, selfish_action_tuple, *args, **kwargs
+ ):
selfish_action = selfish_action_tuple[0]
@@ -87,23 +91,26 @@ def _is_cooperation_needed_to_prevent_punishement(
is_going_to_punish_at_next_step = False
for _ in range(self.lookahead_n_times):
- is_going_to_punish_at_next_step = \
+ is_going_to_punish_at_next_step = (
self._lookahead_for_opponent_punishing(
- selfish_action,
- *args,
- **kwargs)
+ selfish_action, *args, **kwargs
+ )
+ )
if is_going_to_punish_at_next_step:
break
self._to_log["simu_opp_is_punishing"] = is_punishing
- self._to_log["simu_opp_is_going_to_punish_at_next_step"] = \
- is_going_to_punish_at_next_step
- return (is_punishing or is_going_to_punish_at_next_step)
+ self._to_log[
+ "simu_opp_is_going_to_punish_at_next_step"
+ ] = is_going_to_punish_at_next_step
+ return is_punishing or is_going_to_punish_at_next_step
def _lookahead_for_opponent_punishing(
- self, selfish_action, *args, **kwargs) -> bool:
+ self, selfish_action, *args, **kwargs
+ ) -> bool:
raise NotImplementedError()
+
# class ObserveOpponentCallbacks(DefaultCallbacks):
#
# def on_postprocess_trajectory(
diff --git a/marltoolbox/algos/hierarchical.py b/marltoolbox/algos/hierarchical.py
index aa6ab30..c7c3880 100644
--- a/marltoolbox/algos/hierarchical.py
+++ b/marltoolbox/algos/hierarchical.py
@@ -1,13 +1,14 @@
import copy
import logging
-from typing import List, Union, Optional, Dict, Tuple, Iterable
+
+from ray.rllib.utils.threading import with_lock
+from typing import List, Union, Iterable
import torch
from ray import rllib
-from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.typing import TensorType
+from ray.rllib.utils.annotations import override
from marltoolbox.utils import log
@@ -41,9 +42,9 @@ def __init__(
updated_config.update(nested_config["config_update"])
if nested_config["Policy_class"] is None:
raise ValueError(
- f"You must specify the classes for the nested Policies "
- f'in config["nested_config"]["Policy_class"] '
- f'current value is {nested_config["Policy_class"]}.'
+ "You must specify the classes for the nested Policies "
+ 'in config["nested_config"]["Policy_class"]. '
+ f"nested_config: {nested_config}."
)
Policy = nested_config["Policy_class"]
policy = Policy(
@@ -70,6 +71,21 @@ def __init__(
self._to_log = {}
self._already_printed_warnings = []
+ self._merge_all_view_requirements()
+
+ def _merge_all_view_requirements(self):
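+        # Merge the view requirements of all nested policies; keys shared
+        # by several nested policies must be identical.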
+ self.view_requirements = {}
+ for algo in self.algorithms:
+ for k, v in algo.view_requirements.items():
+ self._add_in_view_requirement_or_check_are_equals(k, v)
+
+ def _add_in_view_requirement_or_check_are_equals(self, k, v):
+ if k not in self.view_requirements.keys():
+ self.view_requirements[k] = v
+ else:
+ assert vars(self.view_requirements[k]) == vars(
+ v
+ ), f"{vars(self.view_requirements[k])} must equal {vars(v)}"
def __getattribute__(self, attr):
"""
@@ -87,10 +103,8 @@ def __getattribute__(self, attr):
f"printing this message."
)
logger.info(msg)
- # from warnings import warn
- # warn(msg)
- return object.__getattribute__(
- self.algorithms[self.active_algo_idx], attr
+ return self.algorithms[self.active_algo_idx].__getattribute__(
+ attr
)
except AttributeError as secondary:
raise type(initial)(f"{initial.args} and {secondary.args}")
@@ -113,7 +127,6 @@ def to_log(self, value):
for algo in self.algorithms:
if hasattr(algo, "to_log"):
algo.to_log = {}
- # setattr(algo, "to_log", {})
self._to_log = value
@@ -125,6 +138,14 @@ def model(self):
def model(self, value):
self.algorithms[self.active_algo_idx].model = value
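+    # Expose the active nested algorithm's exploration object so action
+    # computation uses the exploration of the policy currently in use.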
+ @property
+ def exploration(self):
+ return self.algorithms[self.active_algo_idx].exploration
+
+ @exploration.setter
+ def exploration(self, value):
+ self.algorithms[self.active_algo_idx].exploration = value
+
@property
def dist_class(self):
return self.algorithms[self.active_algo_idx].dist_class
@@ -142,9 +163,11 @@ def global_timestep(self, value):
for algo in self.algorithms:
algo.global_timestep = value
+ @override(rllib.policy.Policy)
def on_global_var_update(self, global_vars):
for algo in self.algorithms:
algo.on_global_var_update(global_vars)
+ super().on_global_var_update(global_vars)
@property
def update_target(self):
@@ -155,12 +178,14 @@ def nested_update_target():
return nested_update_target
+ @override(rllib.policy.TorchPolicy)
def get_weights(self):
return {
self.nested_key(i): algo.get_weights()
for i, algo in enumerate(self.algorithms)
}
+ @override(rllib.policy.TorchPolicy)
def set_weights(self, weights):
for i, algo in enumerate(self.algorithms):
algo.set_weights(weights[self.nested_key(i)])
@@ -168,27 +193,22 @@ def set_weights(self, weights):
def nested_key(self, i):
return f"nested_{i}"
- def compute_actions(
- self,
- obs_batch: Union[List[TensorType], TensorType],
- state_batches: Optional[List[TensorType]] = None,
- prev_action_batch: Union[List[TensorType], TensorType] = None,
- prev_reward_batch: Union[List[TensorType], TensorType] = None,
- info_batch: Optional[Dict[str, list]] = None,
- episodes: Optional[List["MultiAgentEpisode"]] = None,
- explore: Optional[bool] = None,
- timestep: Optional[int] = None,
- **kwargs,
- ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
-
+ @with_lock
+ @override(rllib.policy.TorchPolicy)
+ def _compute_action_helper(
+ self, input_dict, state_batches, seq_lens, explore, timestep
+ ):
actions, state_out, extra_fetches = self.algorithms[
self.active_algo_idx
- ].compute_actions(obs_batch)
+ ]._compute_action_helper(
+ input_dict, state_batches, seq_lens, explore, timestep
+ )
return actions, state_out, extra_fetches
+ @with_lock
def learn_on_batch(self, samples: SampleBatch):
- self._update_lr_in_all_optimizers()
+
stats = self._learn_on_batch(samples)
self._log_learning_rates()
return stats
@@ -217,22 +237,12 @@ def optimizer(
The local PyTorch optimizer(s) to use for this Policy.
"""
- # TODO find a clean solution to update the LR when using a LearningRateSchedule
- # TODO this will probably no more be needed when moving to RLLib>1.0.0
- for algo in self.algorithms:
- if hasattr(algo, "cur_lr"):
- for opt in algo._optimizers:
- for p in opt.param_groups:
- p["lr"] = algo.cur_lr
-
all_optimizers = []
for algo_n, algo in enumerate(self.algorithms):
opt = algo.optimizer()
all_optimizers.extend(opt)
return all_optimizers
- # TODO Move this in helper functions
-
def postprocess_trajectory(
self, sample_batch, other_agent_batches=None, episode=None
):
@@ -245,14 +255,13 @@ def _log_learning_rates(self):
        Used to log the learning rates, to check that they are really updated as configured.
"""
for algo_idx, algo in enumerate(self.algorithms):
- self.to_log[f"algo{algo_idx}"] = log._log_learning_rate(algo)
+ self.to_log[f"algo{algo_idx}"] = log.log_learning_rate(algo)
def set_state(self, state: object) -> None:
state = state.copy() # shallow copy
# Set optimizer vars first.
optimizer_vars = state.pop("_optimizer_variables", None)
if optimizer_vars:
- print("self", self)
assert len(optimizer_vars) == len(
self._optimizers
), f"{len(optimizer_vars)} {len(self._optimizers)}"
@@ -260,3 +269,12 @@ def set_state(self, state: object) -> None:
o.load_state_dict(s)
# Then the Policy's (NN) weights.
super().set_state(state)
+
+ def _get_dummy_batch_from_view_requirements(
+ self, batch_size: int = 1
+ ) -> SampleBatch:
+ dummy_sample_batch = super()._get_dummy_batch_from_view_requirements(
+ batch_size
+ )
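+        # Mark the last dummy step as terminal, presumably so that
+        # episode-based postprocessing accepts the dummy batch.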
+ dummy_sample_batch[dummy_sample_batch.DONES][-1] = True
+ return dummy_sample_batch
diff --git a/marltoolbox/algos/lola/__init__.py b/marltoolbox/algos/lola/__init__.py
index 7e97b24..553ba18 100644
--- a/marltoolbox/algos/lola/__init__.py
+++ b/marltoolbox/algos/lola/__init__.py
@@ -1,3 +1,3 @@
from marltoolbox.algos.lola.train_cg_tune_class_API import LOLAPGCG
-from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExact
+from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExactTrainer
from marltoolbox.algos.lola.train_pg_tune_class_API import LOLAPGMatrice
diff --git a/marltoolbox/algos/lola/networks.py b/marltoolbox/algos/lola/networks.py
index dd01eaf..d17c8f8 100644
--- a/marltoolbox/algos/lola/networks.py
+++ b/marltoolbox/algos/lola/networks.py
@@ -100,7 +100,7 @@ def __init__(
self.sample_reward_bis = tf.placeholder(
shape=[None, trace_length],
dtype=tf.float32,
- name="sample_reward",
+ name="sample_reward_bis",
)
self.j = tf.placeholder(shape=[None], dtype=tf.float32, name="j")
@@ -239,6 +239,7 @@ def __init__(
shape=[None, 1], dtype=tf.float32, name="next_value"
)
self.next_v = tf.matmul(self.next_value, self.gamma_array_inverse)
+
if use_critic:
self.target = self.sample_reward_bis + self.next_v
else:
@@ -325,6 +326,7 @@ def __init__(
entropy_coeff * self.entropy
+ weigth_decay * self.weigths_norm
) * self.loss_multiplier
+
self.updateModel = self.trainer.minimize(
total_loss, var_list=self.value_params
)
diff --git a/marltoolbox/algos/lola/train_cg_tune_class_API.py b/marltoolbox/algos/lola/train_cg_tune_class_API.py
index 7a5f903..ac93645 100644
--- a/marltoolbox/algos/lola/train_cg_tune_class_API.py
+++ b/marltoolbox/algos/lola/train_cg_tune_class_API.py
@@ -21,7 +21,7 @@
from marltoolbox.algos.lola.networks import Pnetwork
from marltoolbox.algos.lola.utils import get_monte_carlo, make_cube
from marltoolbox.envs.vectorized_coin_game import AsymVectorizedCoinGame
-from marltoolbox.utils.full_epi_logger import FullEpisodeLogger
+from marltoolbox.utils.log.full_epi_logger import FullEpisodeLogger
PLOT_KEYS = [
"player_1_loss",
@@ -141,7 +141,7 @@ def _init_lola(
grid_size,
gamma,
hidden,
- bs_mul,
+ global_lr_divider,
lr,
env_config,
mem_efficient=True,
@@ -168,6 +168,14 @@ def _init_lola(
use_destabilizer=False,
adding_scaled_weights=False,
always_train_PG=False,
+ use_normalized_rewards=False,
+ use_centered_reward=False,
+ use_rolling_avg_actor_grad=False,
+ process_reward_after_rolling=False,
+ only_process_reward=False,
+ use_rolling_avg_reward=False,
+ reward_processing_bais=False,
+ center_and_normalize_with_rolling_avg=False,
**kwargs,
):
@@ -194,7 +202,6 @@ def _init_lola(
self.grid_size = grid_size
self.gamma = gamma
self.hidden = hidden
- self.bs_mul = bs_mul
self.lr = lr
self.mem_efficient = mem_efficient
self.asymmetry = env_class == AsymVectorizedCoinGame
@@ -232,10 +239,47 @@ def _init_lola(
assert self.adding_scaled_weights > 0.0
self.always_train_PG = always_train_PG
self.last_term_to_use = 0.0
+ self.use_normalized_rewards = use_normalized_rewards
+ self.use_centered_reward = use_centered_reward
+ self.use_rolling_avg_actor_grad = use_rolling_avg_actor_grad
+ self.process_reward_after_rolling = process_reward_after_rolling
+ self.only_process_reward = only_process_reward
+ self.use_rolling_avg_reward = use_rolling_avg_reward
+ self.reward_processing_bais = reward_processing_bais
+ self.center_and_normalize_with_rolling_avg = (
+ center_and_normalize_with_rolling_avg
+ )
+ self._reward_rolling_avg = {}
+ self._reward_rolling_avg_length = 100
+
+ if self.use_rolling_avg_reward and self.use_rolling_avg_actor_grad:
+ self.use_rolling_avg_actor_grad = False
+ print("self.use_normalized_rewards", self.use_normalized_rewards)
+ print("self.use_centered_reward", self.use_centered_reward)
+ print(
+ "self.use_rolling_avg_actor_grad", self.use_rolling_avg_actor_grad
+ )
+
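+        # Heuristic rescaling of the gradient divider depending on which
+        # reward-processing option is enabled (the factors look hand-tuned).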
+ global_lr_divider_mul = 1.0
+ if self.use_normalized_rewards:
+ global_lr_divider_mul *= 2.0
+ if self.use_centered_reward:
+ global_lr_divider_mul *= 1.0
+ if self.use_rolling_avg_actor_grad:
+ global_lr_divider_mul *= 1 / 20.0
+ if self.use_rolling_avg_reward:
+ global_lr_divider_mul *= 1 / 10000.0
+
+ self.global_lr_divider = global_lr_divider * global_lr_divider_mul
+ self.global_lr_divider_original = (
+ global_lr_divider * global_lr_divider_mul
+ )
self.obs_batch = deque(maxlen=self.batch_size)
self.full_episode_logger = FullEpisodeLogger(
- logdir=self._logdir, log_interval=100, log_ful_epi_one_hot_obs=True
+ logdir=self._logdir,
+ log_interval=100,
+ convert_one_hot_obs_to_idx=True,
)
# Setting the training parameters
@@ -249,6 +293,10 @@ def _init_lola(
self.max_epLength = (
trace_length + 1
) # The max allowed length of our episode.
+ self.actor_rolling_avg = [
+ 10.0 / global_lr_divider_mul
+ for agent in range(self.total_n_agents)
+ ]
graph = tf.Graph()
@@ -301,69 +349,23 @@ def _init_lola(
use_critic=use_critic,
)
)
- # Clones of the opponents
- if opp_model:
- self.mainPN_clone = []
- for agent in range(self.total_n_agents):
- self.mainPN_clone.append(
- Pnetwork(
- f"clone_{agent}",
- self.h_size[agent],
- agent,
- self.env,
- trace_length=trace_length,
- batch_size=batch_size,
- changed_config=changed_config,
- ac_lr=ac_lr,
- use_MAE=use_MAE,
- clip_loss_norm=clip_loss_norm,
- sess=self.sess,
- entropy_coeff=entropy_coeff,
- use_critic=use_critic,
- )
- )
if not mem_efficient:
self.cube, self.cube_ops = make_cube(trace_length)
else:
self.cube, self.cube_ops = None, None
- if not opp_model:
- corrections_func(
- self.mainPN,
- batch_size,
- trace_length,
- corrections,
- self.cube,
- clip_lola_update_norm=clip_lola_update_norm,
- lola_correction_multiplier=self.lola_correction_multiplier,
- clip_lola_correction_norm=clip_lola_correction_norm,
- clip_lola_actor_norm=clip_lola_actor_norm,
- )
- else:
- corrections_func(
- [self.mainPN[0], self.mainPN_clone[1]],
- batch_size,
- trace_length,
- corrections,
- self.cube,
- clip_lola_update_norm=clip_lola_update_norm,
- lola_correction_multiplier=self.lola_correction_multiplier,
- clip_lola_correction_norm=clip_lola_correction_norm,
- clip_lola_actor_norm=clip_lola_actor_norm,
- )
- corrections_func(
- [self.mainPN[1], self.mainPN_clone[0]],
- batch_size,
- trace_length,
- corrections,
- self.cube,
- clip_lola_update_norm=clip_lola_update_norm,
- lola_correction_multiplier=self.lola_correction_multiplier,
- clip_lola_correction_norm=clip_lola_correction_norm,
- clip_lola_actor_norm=clip_lola_actor_norm,
- )
- clone_update(self.mainPN_clone)
+ corrections_func(
+ self.mainPN,
+ batch_size,
+ trace_length,
+ corrections,
+ self.cube,
+ clip_lola_update_norm=clip_lola_update_norm,
+ lola_correction_multiplier=self.lola_correction_multiplier,
+ clip_lola_correction_norm=clip_lola_correction_norm,
+ clip_lola_actor_norm=clip_lola_actor_norm,
+ )
if self.use_PG_exploiter:
simple_actor_training_func(
@@ -553,19 +555,24 @@ def step(self):
pow_series = np.arange(self.trace_length)
discount = np.array([pow(self.gamma, item) for item in pow_series])
+ reward_mean_p0 = np.array(trainBatch0[2]).mean()
+ reward_mean_p1 = np.array(trainBatch1[2]).mean()
+ reward_std_p0 = np.array(trainBatch0[2]).std()
+ reward_std_p1 = np.array(trainBatch1[2]).std()
+
(
sample_return0,
sample_reward0,
sample_reward0_bis,
) = self.compute_centered_discounted_r(
- rewards=trainBatch0[2], discount=discount
+ rewards=trainBatch0[2], discount=discount, player_key="player_0"
)
(
sample_return1,
sample_reward1,
sample_reward1_bis,
) = self.compute_centered_discounted_r(
- rewards=trainBatch1[2], discount=discount
+ rewards=trainBatch1[2], discount=discount, player_key="player_1"
)
state_input0 = np.concatenate(trainBatch0[0], axis=0)
@@ -607,34 +614,6 @@ def step(self):
},
)
- if self.opp_model:
- ## update local clones
- update_clone = [
- self.mainPN_clone[0].update,
- self.mainPN_clone[1].update,
- ]
- feed_dict = {
- self.mainPN_clone[0].state_input: state_input1,
- self.mainPN_clone[0].actions: actions1,
- self.mainPN_clone[0].sample_return: sample_return1,
- self.mainPN_clone[0].sample_reward: sample_reward1,
- self.mainPN_clone[1].state_input: state_input0,
- self.mainPN_clone[1].actions: actions0,
- self.mainPN_clone[1].sample_return: sample_return0,
- self.mainPN_clone[1].sample_reward: sample_reward0,
- self.mainPN_clone[0].gamma_array: np.reshape(
- discount, [1, -1]
- ),
- self.mainPN_clone[1].gamma_array: np.reshape(
- discount, [1, -1]
- ),
- self.mainPN_clone[0].is_training: True,
- self.mainPN_clone[1].is_training: True,
- }
- num_loops = 50 if self.timestep == 0 else 1
- for _ in range(num_loops):
- self.sess.run(update_clone, feed_dict=feed_dict)
-
if self.lr_decay:
lr_decay = (self.num_episodes - self.timestep) / self.num_episodes
else:
@@ -667,25 +646,6 @@ def step(self):
self.mainPN[0].is_training: True,
self.mainPN[1].is_training: True,
}
- if self.opp_model:
- feed_dict.update(
- {
- self.mainPN_clone[0].state_input: state_input1,
- self.mainPN_clone[0].actions: actions1,
- self.mainPN_clone[0].sample_return: sample_return1,
- self.mainPN_clone[0].sample_reward: sample_reward1,
- self.mainPN_clone[1].state_input: state_input0,
- self.mainPN_clone[1].actions: actions0,
- self.mainPN_clone[1].sample_return: sample_return0,
- self.mainPN_clone[1].sample_reward: sample_reward0,
- self.mainPN_clone[0].gamma_array: np.reshape(
- discount, [1, -1]
- ),
- self.mainPN_clone[1].gamma_array: np.reshape(
- discount, [1, -1]
- ),
- }
- )
lola_training_list = [
self.mainPN[0].value,
@@ -728,103 +688,46 @@ def step(self):
)
( # Player_red
- values,
+ values_p1,
updateModel_1,
update1,
player_1_value,
player_1_target,
player_1_loss,
- entropy_p_0,
- v_0_log,
- actor_target_error_0,
- actor_loss_0,
- parameters_norm_0,
- second_order0,
- v_0_grad_theta_0,
- second_order0_sum,
- actor_grad_sum_0,
+ entropy_p1,
+ v_0_log_p1,
+ actor_target_error_p1,
+ actor_loss_p1,
+ parameters_norm_p1,
+ second_order_p1,
+ v_0_grad_theta_p1,
+ second_order_sum_p1,
+ actor_grad_sum_p1,
v_0_grad_01,
- multiply0,
+ multiply_p1,
# Player_blue
- values_1,
- updateModel_2,
- update2,
+ values_p2,
+ updateModel_p2,
+ update_p2,
player_2_value,
player_2_target,
player_2_loss,
- entropy_p_1,
- v_1_log,
- actor_target_error_1,
- actor_loss_1,
- parameters_norm_1,
- second_order1,
- v_1_grad_theta_1,
- second_order1_sum,
- actor_grad_sum_1,
+ entropy_p2,
+ v_1_log_p2,
+ actor_target_error_p2,
+ actor_loss_p2,
+ parameters_norm_p2,
+ second_order_p2,
+ v_1_grad_theta_p2,
+ second_order_sum_p2,
+ actor_grad_sum_p2,
) = self.sess.run(lola_training_list, feed_dict=feed_dict)
- if self.warmup:
- update1 = update1 * self.warmup_step_n / self.warmup
- update2 = update2 * self.warmup_step_n / self.warmup
- if self.lr_decay:
- update1 = update1 * lr_decay
- update2 = update2 * lr_decay
-
- update1_sum = sum(update1) / self.bs_mul
- update2_sum = sum(update2) / self.bs_mul
-
- update(
- self.mainPN,
- self.lr,
- update1 / self.bs_mul,
- update2 / self.bs_mul,
- use_actions_from_exploiter,
+ update1_sum, update2_sum = self._update_players(
+ update1, update_p2, lr_decay, use_actions_from_exploiter
)
if self.use_PG_exploiter:
- # Update policy networks
- feed_dict = {
- self.mainPN[0].state_input: state_input0,
- self.mainPN[0].sample_return: sample_return0,
- self.mainPN[0].actions: actions0,
- self.mainPN[2].state_input: state_input1,
- self.mainPN[2].sample_return: sample_return1,
- self.mainPN[2].actions: actions1,
- self.mainPN[0].sample_reward: sample_reward0,
- self.mainPN[2].sample_reward: sample_reward1,
- self.mainPN[0].sample_reward_bis: sample_reward0_bis,
- self.mainPN[2].sample_reward_bis: sample_reward1_bis,
- self.mainPN[0].gamma_array: np.reshape(discount, [1, -1]),
- self.mainPN[2].gamma_array: np.reshape(discount, [1, -1]),
- self.mainPN[0].next_value: value_0_next,
- self.mainPN[2].next_value: expl_value_next,
- self.mainPN[0].gamma_array_inverse: np.reshape(
- self.discount_array, [1, -1]
- ),
- self.mainPN[2].gamma_array_inverse: np.reshape(
- self.discount_array, [1, -1]
- ),
- self.mainPN[0].loss_multiplier: [lr_decay],
- self.mainPN[2].loss_multiplier: [lr_decay],
- self.mainPN[0].is_training: True,
- self.mainPN[2].is_training: True,
- }
-
- lola_training_list = [
- self.mainPN[2].value,
- self.mainPN[2].updateModel,
- self.mainPN[2].delta,
- self.mainPN[2].value,
- self.mainPN[2].target,
- self.mainPN[2].loss,
- self.mainPN[2].entropy,
- self.mainPN[2].v_0_log,
- self.mainPN[2].actor_target_error,
- self.mainPN[2].actor_loss,
- self.mainPN[2].weigths_norm,
- self.mainPN[2].grad,
- self.mainPN[2].grad_sum,
- ]
(
pg_expl_values,
pg_expl_updateModel,
@@ -839,32 +742,405 @@ def step(self):
pg_expl_parameters_norm,
pg_expl_v_grad_theta,
pg_expl_actor_grad_sum,
- ) = self.sess.run(lola_training_list, feed_dict=feed_dict)
+ pg_expl_update_sum,
+ ) = self._update_pg_exploiter(
+ state_input0,
+ sample_return0,
+ actions0,
+ state_input1,
+ sample_return1,
+ actions1,
+ sample_reward0,
+ sample_reward1,
+ sample_reward0_bis,
+ sample_reward1_bis,
+ discount,
+ value_0_next,
+ expl_value_next,
+ lr_decay,
+ )
+ else:
+ pg_expl_player_loss = None
+ pg_expl_v_log = None
+ pg_expl_entropy = None
+ pg_expl_actor_loss = None
+ pg_expl_parameters_norm = None
+ pg_expl_update_sum = None
+ pg_expl_actor_grad_sum = None
+
+ self._update_actor_rolling_avg(
+ actor_grad_sum_p1,
+ actor_grad_sum_p2,
+ trainBatch0[2],
+ trainBatch1[2],
+ )
- if self.warmup:
- pg_expl_update = (
- pg_expl_update * self.warmup_step_n / self.warmup
+ print("update params")
+ to_report = self._prepare_logs(
+ to_report,
+ last_info,
+ player_1_loss,
+ player_2_loss,
+ v_0_log_p1,
+ v_1_log_p2,
+ entropy_p1,
+ entropy_p2,
+ actor_loss_p1,
+ actor_loss_p2,
+ parameters_norm_p1,
+ parameters_norm_p2,
+ second_order_sum_p1,
+ second_order_sum_p2,
+ update1_sum,
+ update2_sum,
+ actor_grad_sum_p1,
+ actor_grad_sum_p2,
+ pg_expl_player_loss,
+ pg_expl_v_log,
+ pg_expl_entropy,
+ pg_expl_actor_loss,
+ pg_expl_parameters_norm,
+ pg_expl_update_sum,
+ pg_expl_actor_grad_sum,
+ reward_mean_p0,
+ reward_mean_p1,
+ reward_std_p0,
+ reward_std_p1,
+ sample_return0,
+ sample_return1,
+ sample_reward0,
+ sample_reward1,
+ values_p1,
+ values_p2,
+ player_1_target,
+ player_2_target,
+ )
+ return to_report
+
+ def compute_centered_discounted_r(self, rewards, discount, player_key=""):
+ if not self.only_process_reward:
+ rewards = self._process_rewards(
+ rewards, preprocess=True, key=player_key + "_rewards"
+ )
+
+ sample_return = np.reshape(
+ get_monte_carlo(
+ rewards, self.y, self.trace_length, self.batch_size
+ ),
+ [self.batch_size, -1],
+ )
+
+ if self.only_process_reward:
+ rewards = self._process_rewards(
+ rewards, preprocess=True, key=player_key + "_rewards"
+ )
+
+ if self.correction_reward_baseline_per_step:
+ sample_reward = discount * np.reshape(
+ rewards - np.mean(np.array(rewards), axis=0),
+ [-1, self.trace_length],
+ )
+ else:
+ sample_reward = discount * np.reshape(
+ rewards - np.mean(rewards), [-1, self.trace_length]
+ )
+ sample_reward_bis = discount * np.reshape(
+ rewards, [-1, self.trace_length]
+ )
+
+ if not self.only_process_reward:
+ sample_return = self._process_rewards(
+ sample_return,
+ preprocess=False,
+ key=player_key + "_sample_return",
+ )
+ sample_reward = self._process_rewards(
+ sample_reward, preprocess=False, key=player_key + "_sample_reward"
+ )
+ sample_reward_bis = self._process_rewards(
+ sample_reward_bis,
+ preprocess=False,
+ key=player_key + "_sample_reward_bis",
+ )
+ return sample_return, sample_reward, sample_reward_bis
+
+ def _process_rewards(self, rewards, preprocess, key: str):
+ postprocess = not preprocess
+ preprocessing = preprocess and not self.process_reward_after_rolling
+ postprocessing = postprocess and self.process_reward_after_rolling
+ if preprocessing or postprocessing:
+ if self.use_normalized_rewards:
+ return self._normalize_rewards(rewards, key)
+ elif self.use_centered_reward:
+ return self._center_reward(rewards, key)
+ return rewards
+
+ def _normalize_rewards(self, rewards, key: str):
+ r_array = np.array(rewards)
+ if self.center_and_normalize_with_rolling_avg:
+ mean_key = key + "_mean"
+ std_key = key + "_std"
+ self._update_reward_rolling_avg(r_array.mean(), mean_key)
+ self._update_reward_rolling_avg(r_array.std(), std_key)
+ rewards = (
+ r_array - self._get_reward_rolling_avg(mean_key)
+ ) / self._get_reward_rolling_avg(std_key)
+ else:
+ rewards = (r_array - r_array.mean()) / r_array.std()
+ if self.reward_processing_bais:
+ rewards += self.reward_processing_bais
+ return rewards.tolist()
+
+ def _center_reward(self, rewards, key: str):
+ r_array = np.array(rewards)
+ if self.center_and_normalize_with_rolling_avg:
+ self._update_reward_rolling_avg(r_array.mean(), key)
+ rewards = r_array - self._get_reward_rolling_avg(key)
+ else:
+ rewards = r_array - r_array.mean()
+ if self.reward_processing_bais:
+ rewards += self.reward_processing_bais
+ return rewards.tolist()
+
+ def _update_reward_rolling_avg(self, value, key):
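+        # Exponentially decayed running sum with an effective window of
+        # self._reward_rolling_avg_length steps; _get_reward_rolling_avg
+        # divides by this length to recover the average.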
+ if key in self._reward_rolling_avg:
+ self._reward_rolling_avg[key] = value + self._reward_rolling_avg[
+ key
+ ] * (1 - (1 / self._reward_rolling_avg_length))
+ else:
+ self._reward_rolling_avg[key] = (
+ value * self._reward_rolling_avg_length
+ )
+
+ def _get_reward_rolling_avg(self, key):
+ return self._reward_rolling_avg[key] / self._reward_rolling_avg_length
+
+ def _update_bs_mul_wt_rolling_avg(self):
+ rolling_avg = None
+ if self.use_rolling_avg_actor_grad:
+ rolling_avg = self.use_rolling_avg_actor_grad
+ if self.use_rolling_avg_reward:
+ rolling_avg = self.use_rolling_avg_reward
+
+ if rolling_avg is not None:
+ self.global_lr_divider = [
+ self.global_lr_divider_original * v / rolling_avg
+ for v in self.actor_rolling_avg
+ ]
+
+ def _update_actor_rolling_avg(
+ self,
+ actor_grad_sum_0,
+ actor_grad_sum_1,
+ reward_player_0,
+ reward_player_1,
+ ):
+ if self.use_rolling_avg_actor_grad or self.use_rolling_avg_reward:
+ roll_avg_p0 = self.actor_rolling_avg[0]
+ roll_avg_p1 = self.actor_rolling_avg[1]
+ if self.use_rolling_avg_actor_grad:
+ roll_avg_p0 = np.abs(actor_grad_sum_0) + roll_avg_p0 * (
+ 1 - (1 / self.use_rolling_avg_actor_grad)
)
- if self.lr_decay:
- pg_expl_update = pg_expl_update * lr_decay
+ roll_avg_p1 = np.abs(actor_grad_sum_1) + roll_avg_p1 * (
+ 1 - (1 / self.use_rolling_avg_actor_grad)
+ )
+ if self.use_rolling_avg_reward:
+ roll_avg_p0 = np.sum(np.abs(reward_player_0)) + roll_avg_p0 * (
+ 1 - (1 / self.use_rolling_avg_reward)
+ )
+ roll_avg_p1 = np.sum(np.abs(reward_player_1)) + roll_avg_p1 * (
+ 1 - (1 / self.use_rolling_avg_reward)
+ )
+ self.actor_rolling_avg[0] = roll_avg_p0
+ self.actor_rolling_avg[1] = roll_avg_p1
- # pg_expl_update_to_log = pg_expl_update
- pg_expl_update_sum = sum(pg_expl_update) / self.bs_mul
+ def _update_players(
+ self, update1, update2, lr_decay, use_actions_from_exploiter
+ ):
+ if self.warmup:
+ update1 = update1 * self.warmup_step_n / self.warmup
+ update2 = update2 * self.warmup_step_n / self.warmup
+ if self.lr_decay:
+ update1 = update1 * lr_decay
+ update2 = update2 * lr_decay
- update_single(
- self.mainPN[2], self.lr, pg_expl_update / self.bs_mul
+ if self.use_rolling_avg_actor_grad or self.use_rolling_avg_reward:
+ self._update_bs_mul_wt_rolling_avg()
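+            # With rolling averages enabled, global_lr_divider is a
+            # per-player list, so each player's update uses its own divider.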
+ update1_sum = sum(update1) / self.global_lr_divider[0]
+ update2_sum = sum(update2) / self.global_lr_divider[1]
+ update(
+ self.mainPN,
+ self.lr,
+ update1 / self.global_lr_divider[0],
+ update2 / self.global_lr_divider[1],
+ use_actions_from_exploiter,
)
+ else:
+ update1_sum = sum(update1) / self.global_lr_divider
+ update2_sum = sum(update2) / self.global_lr_divider
+ update(
+ self.mainPN,
+ self.lr,
+ update1 / self.global_lr_divider,
+ update2 / self.global_lr_divider,
+ use_actions_from_exploiter,
+ )
+ return update1_sum, update2_sum
- if self.timestep >= self.start_using_exploiter_at_update_n:
- if self.timestep % self.every_n_updates_copy_weights == 0:
- copy_weigths(
- from_policy=self.mainPN[2],
- to_policy=self.mainPN[1],
- adding_scaled_weights=self.adding_scaled_weights,
- )
+ def _update_pg_exploiter(
+ self,
+ state_input0,
+ sample_return0,
+ actions0,
+ state_input1,
+ sample_return1,
+ actions1,
+ sample_reward0,
+ sample_reward1,
+ sample_reward0_bis,
+ sample_reward1_bis,
+ discount,
+ value_0_next,
+ expl_value_next,
+ lr_decay,
+ ):
+ # Update policy networks
+ feed_dict = {
+ self.mainPN[0].state_input: state_input0,
+ self.mainPN[0].sample_return: sample_return0,
+ self.mainPN[0].actions: actions0,
+ self.mainPN[2].state_input: state_input1,
+ self.mainPN[2].sample_return: sample_return1,
+ self.mainPN[2].actions: actions1,
+ self.mainPN[0].sample_reward: sample_reward0,
+ self.mainPN[2].sample_reward: sample_reward1,
+ self.mainPN[0].sample_reward_bis: sample_reward0_bis,
+ self.mainPN[2].sample_reward_bis: sample_reward1_bis,
+ self.mainPN[0].gamma_array: np.reshape(discount, [1, -1]),
+ self.mainPN[2].gamma_array: np.reshape(discount, [1, -1]),
+ self.mainPN[0].next_value: value_0_next,
+ self.mainPN[2].next_value: expl_value_next,
+ self.mainPN[0].gamma_array_inverse: np.reshape(
+ self.discount_array, [1, -1]
+ ),
+ self.mainPN[2].gamma_array_inverse: np.reshape(
+ self.discount_array, [1, -1]
+ ),
+ self.mainPN[0].loss_multiplier: [lr_decay],
+ self.mainPN[2].loss_multiplier: [lr_decay],
+ self.mainPN[0].is_training: True,
+ self.mainPN[2].is_training: True,
+ }
- print("update params")
+ lola_training_list = [
+ self.mainPN[2].value,
+ self.mainPN[2].updateModel,
+ self.mainPN[2].delta,
+ self.mainPN[2].value,
+ self.mainPN[2].target,
+ self.mainPN[2].loss,
+ self.mainPN[2].entropy,
+ self.mainPN[2].v_0_log,
+ self.mainPN[2].actor_target_error,
+ self.mainPN[2].actor_loss,
+ self.mainPN[2].weigths_norm,
+ self.mainPN[2].grad,
+ self.mainPN[2].grad_sum,
+ ]
+ (
+ pg_expl_values,
+ pg_expl_updateModel,
+ pg_expl_update,
+ pg_expl_player_value,
+ pg_expl_player_target,
+ pg_expl_player_loss,
+ pg_expl_entropy,
+ pg_expl_v_log,
+ pg_expl_actor_target_error,
+ pg_expl_actor_loss,
+ pg_expl_parameters_norm,
+ pg_expl_v_grad_theta,
+ pg_expl_actor_grad_sum,
+ ) = self.sess.run(lola_training_list, feed_dict=feed_dict)
+
+ if self.warmup:
+ pg_expl_update = pg_expl_update * self.warmup_step_n / self.warmup
+ if self.lr_decay:
+ pg_expl_update = pg_expl_update * lr_decay
+ # pg_expl_update_to_log = pg_expl_update
+ pg_expl_update_sum = sum(pg_expl_update) / self.global_lr_divider
+
+ update_single(
+ self.mainPN[2], self.lr, pg_expl_update / self.global_lr_divider
+ )
+
+ if self.timestep >= self.start_using_exploiter_at_update_n:
+ if self.timestep % self.every_n_updates_copy_weights == 0:
+ copy_weigths(
+ from_policy=self.mainPN[2],
+ to_policy=self.mainPN[1],
+ adding_scaled_weights=self.adding_scaled_weights,
+ )
+
+ return (
+ pg_expl_values,
+ pg_expl_updateModel,
+ pg_expl_update,
+ pg_expl_player_value,
+ pg_expl_player_target,
+ pg_expl_player_loss,
+ pg_expl_entropy,
+ pg_expl_v_log,
+ pg_expl_actor_target_error,
+ pg_expl_actor_loss,
+ pg_expl_parameters_norm,
+ pg_expl_v_grad_theta,
+ pg_expl_actor_grad_sum,
+ pg_expl_update_sum,
+ )
+
+ def _prepare_logs(
+ self,
+ to_report,
+ last_info,
+ player_1_loss,
+ player_2_loss,
+ v_0_log,
+ v_1_log,
+ entropy_p_0,
+ entropy_p_1,
+ actor_loss_0,
+ actor_loss_1,
+ parameters_norm_0,
+ parameters_norm_1,
+ second_order0_sum,
+ second_order1_sum,
+ update1_sum,
+ update2_sum,
+ actor_grad_sum_0,
+ actor_grad_sum_1,
+ pg_expl_player_loss,
+ pg_expl_v_log,
+ pg_expl_entropy,
+ pg_expl_actor_loss,
+ pg_expl_parameters_norm,
+ pg_expl_update_sum,
+ pg_expl_actor_grad_sum,
+ reward_mean_p0,
+ reward_mean_p1,
+ reward_std_p0,
+ reward_std_p1,
+ sample_return0,
+ sample_return1,
+ sample_reward0,
+ sample_reward1,
+ values,
+ values_1,
+ player_1_target,
+ player_2_target,
+ ):
rlog = np.sum(self.rList[-self.summary_len :], 0)
to_plot = {}
@@ -891,6 +1167,10 @@ def step(self):
last_info.pop("available_actions", None)
training_info = {
+ "reward_mean_p0": reward_mean_p0,
+ "reward_mean_p1": reward_mean_p1,
+ "reward_std_p0": reward_std_p0,
+ "reward_std_p1": reward_std_p1,
"player_1_loss": player_1_loss,
"player_2_loss": player_2_loss,
"v_0_log": v_0_log,
@@ -911,16 +1191,18 @@ def step(self):
/ self.num_episodes,
}
# Logging distribution (can be a speed bottleneck)
- # training_info.update({
- # "sample_return0": sample_return0,
- # "sample_return1": sample_return1,
- # "sample_reward0": sample_reward0,
- # "sample_reward1": sample_reward1,
- # "player_1_values": values,
- # "player_2_values": values_1,
- # "player_1_target": player_1_target,
- # "player_2_target": player_2_target,
- # })
+ # training_info.update(
+ # {
+ # "sample_return0": sample_return0,
+ # "sample_return1": sample_return1,
+ # "sample_reward0": sample_reward0,
+ # "sample_reward1": sample_reward1,
+ # "player_1_values": values,
+ # "player_2_values": values_1,
+ # "player_1_target": player_1_target,
+ # "player_2_target": player_2_target,
+ # }
+ # )
# self.update1_list.clear()
# self.update2_list.clear()
@@ -952,30 +1234,14 @@ def step(self):
+ to_report.get("player_red_pick_own_color", 0.0)
)
+ if self.use_rolling_avg_actor_grad:
+ to_report["rolling_avg_actor_grad_p0"] = self.actor_rolling_avg[0]
+ to_report["rolling_avg_actor_grad_p1"] = self.actor_rolling_avg[1]
+ if self.use_rolling_avg_reward:
+ to_report["rolling_avg_reward_p0"] = self.actor_rolling_avg[0]
+ to_report["rolling_avg_reward_p1"] = self.actor_rolling_avg[1]
return to_report
- def compute_centered_discounted_r(self, rewards, discount):
- sample_return = np.reshape(
- get_monte_carlo(
- rewards, self.y, self.trace_length, self.batch_size
- ),
- [self.batch_size, -1],
- )
-
- if self.correction_reward_baseline_per_step:
- sample_reward = discount * np.reshape(
- rewards - np.mean(np.array(rewards), axis=0),
- [-1, self.trace_length],
- )
- else:
- sample_reward = discount * np.reshape(
- rewards - np.mean(rewards), [-1, self.trace_length]
- )
- sample_reward_bis = discount * np.reshape(
- rewards, [-1, self.trace_length]
- )
- return sample_return, sample_reward, sample_reward_bis
-
def _log_one_step_in_full_episode(self, s, r, actions, obs, info):
self.full_episode_logger.on_episode_step(
step_data={
diff --git a/marltoolbox/algos/lola/train_exact_tune_class_API.py b/marltoolbox/algos/lola/train_exact_tune_class_API.py
index ef844ab..472c00a 100644
--- a/marltoolbox/algos/lola/train_exact_tune_class_API.py
+++ b/marltoolbox/algos/lola/train_exact_tune_class_API.py
@@ -11,7 +11,7 @@
import json
import os
import random
-
+import torch
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers
@@ -29,22 +29,35 @@ class Qnetwork:
Q-network that is either a look-up table or an MLP with 1 hidden layer.
"""
- def __init__(self, myScope, num_hidden, sess, simple_net=True):
+ def __init__(
+ self, myScope, num_hidden, sess, simple_net=True, n_actions=2, std=1.0
+ ):
with tf.variable_scope(myScope):
- self.input_place = tf.placeholder(shape=[5], dtype=tf.int32)
+ # self.input_place = tf.placeholder(shape=[5], dtype=tf.int32)
+ # if simple_net:
+ # self.p_act = tf.Variable(tf.random_normal([5, 1], stddev=3.0))
+ # else:
+ # act = tf.nn.tanh(
+ # layers.fully_connected(
+ # tf.one_hot(self.input_place, 5, dtype=tf.float32),
+ # num_outputs=num_hidden,
+ # activation_fn=None,
+ # )
+ # )
+ # self.p_act = layers.fully_connected(
+ # act, num_outputs=1, activation_fn=None
+ # )
+ self.input_place = tf.placeholder(
+ shape=[n_actions ** 2 + 1], dtype=tf.int32
+ )
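+            # The table is indexed by the joint action of the previous step
+            # (n_actions ** 2 states) plus one extra entry for the initial state.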
if simple_net:
- self.p_act = tf.Variable(tf.random_normal([5, 1], stddev=1.0))
- else:
- act = tf.nn.tanh(
- layers.fully_connected(
- tf.one_hot(self.input_place, 5, dtype=tf.float32),
- num_outputs=num_hidden,
- activation_fn=None,
+ self.p_act = tf.Variable(
+ tf.random_normal(
+ [n_actions ** 2 + 1, n_actions], stddev=std
)
)
- self.p_act = layers.fully_connected(
- act, num_outputs=1, activation_fn=None
- )
+ else:
+ raise ValueError()
self.parameters = []
for i in tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope=myScope
@@ -63,51 +76,135 @@ def update(mainQN, lr, final_delta_1_v, final_delta_2_v):
)
-def corrections_func(mainQN, corrections, gamma, pseudo, reg):
+def corrections_func(mainQN, corrections, gamma, pseudo, reg, n_actions=2):
+ print_opts = []
+
mainQN[0].lr_correction = tf.placeholder(shape=[1], dtype=tf.float32)
mainQN[1].lr_correction = tf.placeholder(shape=[1], dtype=tf.float32)
theta_1_all = mainQN[0].p_act
theta_2_all = mainQN[1].p_act
- theta_1 = tf.slice(theta_1_all, [0, 0], [4, 1])
- theta_2 = tf.slice(theta_2_all, [0, 0], [4, 1])
-
- theta_1_0 = tf.slice(theta_1_all, [4, 0], [1, 1])
- theta_2_0 = tf.slice(theta_2_all, [4, 0], [1, 1])
- p_1 = tf.nn.sigmoid(theta_1)
- p_2 = tf.nn.sigmoid(theta_2)
- mainQN[0].policy = tf.nn.sigmoid(theta_1_all)
- mainQN[1].policy = tf.nn.sigmoid(theta_2_all)
-
- p_1_0 = tf.nn.sigmoid(theta_1_0)
- p_2_0 = tf.nn.sigmoid(theta_2_0)
+ # Using sigmoid + normalize to keep values similar to the official
+ # implementation
+ pi_player_1_for_all_states = tf.nn.sigmoid(theta_1_all)
+ pi_player_2_for_all_states = tf.nn.sigmoid(theta_2_all)
+ print_opts.append(
+ tf.print("pi_player_1_for_all_states", pi_player_1_for_all_states)
+ )
+ print_opts.append(
+ tf.print("pi_player_2_for_all_states", pi_player_2_for_all_states)
+ )
+ sum_1 = tf.reduce_sum(pi_player_1_for_all_states, axis=1)
+ sum_1 = tf.stack([sum_1 for _ in range(n_actions)], axis=1)
+ sum_2 = tf.reduce_sum(pi_player_2_for_all_states, axis=1)
+ sum_2 = tf.stack([sum_2 for _ in range(n_actions)], axis=1)
+ pi_player_1_for_all_states = pi_player_1_for_all_states / sum_1
+ pi_player_2_for_all_states = pi_player_2_for_all_states / sum_2
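+    # Dividing by the per-state sums turns the independent sigmoids into a
+    # proper categorical distribution over the n_actions actions.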
+ print_opts.append(
+ tf.print("pi_player_1_for_all_states", pi_player_1_for_all_states)
+ )
+ mainQN[0].policy = pi_player_1_for_all_states
+ mainQN[1].policy = pi_player_2_for_all_states
+ n_states = int(pi_player_1_for_all_states.shape[0])
+ print_opts.append(tf.print("mainQN[0].policy", mainQN[0].policy))
+ print_opts.append(tf.print("mainQN[1].policy", mainQN[1].policy))
+ pi_player_1_for_states_in_game = tf.slice(
+ pi_player_1_for_all_states,
+ [0, 0],
+ [n_states - 1, n_actions],
+ )
+ pi_player_2_for_states_in_game = tf.slice(
+ pi_player_2_for_all_states,
+ [0, 0],
+ [n_states - 1, n_actions],
+ )
+ pi_player_1_for_initial_state = tf.slice(
+ pi_player_1_for_all_states,
+ [n_states - 1, 0],
+ [1, n_actions],
+ )
+ pi_player_2_for_initial_state = tf.slice(
+ pi_player_2_for_all_states,
+ [n_states - 1, 0],
+ [1, n_actions],
+ )
+ pi_player_1_for_initial_state = tf.transpose(pi_player_1_for_initial_state)
+ pi_player_2_for_initial_state = tf.transpose(pi_player_2_for_initial_state)
+ print_opts.append(
+ tf.print(
+ "pi_player_1_for_initial_state", pi_player_1_for_initial_state
+ )
+ )
- p_1_0_v = tf.concat([p_1_0, (1 - p_1_0)], 0)
- p_2_0_v = tf.concat([p_2_0, (1 - p_2_0)], 0)
+ s_0 = tf.reshape(
+ tf.matmul(
+ pi_player_1_for_initial_state,
+ tf.transpose(pi_player_2_for_initial_state),
+ ),
+ [-1, 1],
+ )
- s_0 = tf.reshape(tf.matmul(p_1_0_v, tf.transpose(p_2_0_v)), [-1, 1])
+ pi_p1 = tf.reshape(pi_player_1_for_states_in_game, [n_states - 1, -1, 1])
+ pi_p2 = tf.reshape(pi_player_2_for_states_in_game, [n_states - 1, -1, 1])
- # CC, CD, DC, DD
+ all_actions_proba_pairs = []
+ for action_p1 in range(n_actions):
+ for action_p2 in range(n_actions):
+ all_actions_proba_pairs.append(
+ tf.multiply(pi_p1[:, action_p1], pi_p2[:, action_p2])
+ )
P = tf.concat(
- [
- tf.multiply(p_1, p_2),
- tf.multiply(p_1, 1 - p_2),
- tf.multiply(1 - p_1, p_2),
- tf.multiply(1 - p_1, 1 - p_2),
- ],
+ all_actions_proba_pairs,
1,
)
- R_1 = tf.placeholder(shape=[4, 1], dtype=tf.float32)
- R_2 = tf.placeholder(shape=[4, 1], dtype=tf.float32)
-
- I_m_P = tf.diag([1.0, 1.0, 1.0, 1.0]) - P * gamma
+ # if n_actions == 2:
+ # # CC, CD, DC, DD
+ #
+ # P = tf.concat(
+ # [
+ # tf.multiply(pi_p1[:, 0], pi_p2[:, 0]),
+ # tf.multiply(pi_p1[:, 0], pi_p2[:, 1]),
+ # tf.multiply(pi_p1[:, 1], pi_p2[:, 0]),
+ # tf.multiply(pi_p1[:, 1], pi_p2[:, 1]),
+ # ],
+ # 1,
+ # )
+ # elif n_actions == 3:
+ # # CC, CD, CN, DC, DD, DN, NC, ND, NN
+ # P = tf.concat(
+ # [
+ # tf.multiply(pi_p1[:, 0], pi_p2[:, 0]),
+ # tf.multiply(pi_p1[:, 0], pi_p2[:, 1]),
+ # tf.multiply(pi_p1[:, 0], pi_p2[:, 2]),
+ # tf.multiply(pi_p1[:, 1], pi_p2[:, 0]),
+ # tf.multiply(pi_p1[:, 1], pi_p2[:, 1]),
+ # tf.multiply(pi_p1[:, 1], pi_p2[:, 2]),
+ # tf.multiply(pi_p1[:, 2], pi_p2[:, 0]),
+ # tf.multiply(pi_p1[:, 2], pi_p2[:, 1]),
+ # tf.multiply(pi_p1[:, 2], pi_p2[:, 2]),
+ # ],
+ # 1,
+ # )
+ # else:
+ # raise ValueError(f"n_actions {n_actions}")
+ # R_1 = tf.placeholder(shape=[4, 1], dtype=tf.float32)
+ # R_2 = tf.placeholder(shape=[4, 1], dtype=tf.float32)
+ R_1 = tf.placeholder(shape=[n_actions ** 2, 1], dtype=tf.float32)
+ R_2 = tf.placeholder(shape=[n_actions ** 2, 1], dtype=tf.float32)
+
+ # I_m_P = tf.diag([1.0, 1.0, 1.0, 1.0]) - P * gamma
+ I_m_P = tf.diag([1.0] * (n_actions ** 2)) - P * gamma
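+    # Closed-form value of the discounted iterated game:
+    # V_i = s_0^T (I - gamma * P)^{-1} R_i, with P the state-transition
+    # matrix induced by both policies and s_0 the initial state distribution.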
v_0 = tf.matmul(
tf.matmul(tf.matrix_inverse(I_m_P), R_1), s_0, transpose_a=True
)
v_1 = tf.matmul(
tf.matmul(tf.matrix_inverse(I_m_P), R_2), s_0, transpose_a=True
)
+ print_opts.append(tf.print("s_0", s_0))
+ print_opts.append(tf.print("I_m_P", I_m_P))
+ print_opts.append(tf.print("R_1", R_1))
+ print_opts.append(tf.print("v_0", v_0))
if reg > 0:
for indx, _ in enumerate(mainQN[0].parameters):
v_0 -= reg * tf.reduce_sum(
@@ -145,6 +242,7 @@ def corrections_func(mainQN, corrections, gamma, pseudo, reg):
tf.reshape(v_0_grad_theta_0_wrong, [param_len, 1]),
)
+ # with tf.control_dependencies(print_opts):
second_order0 = flatgrad(multiply0, mainQN[0].parameters)
second_order1 = flatgrad(multiply1, mainQN[1].parameters)
@@ -158,11 +256,11 @@ def corrections_func(mainQN, corrections, gamma, pseudo, reg):
mainQN[1].delta += tf.multiply(second_order1, mainQN[1].lr_correction)
-class LOLAExact(tune.Trainable):
+class LOLAExactTrainer(tune.Trainable):
def _init_lola(
self,
- env_name,
*,
+ env_name="IteratedAsymBoS",
num_episodes=50,
trace_length=200,
simple_net=True,
@@ -175,10 +273,12 @@ def _init_lola(
gamma=0.96,
with_linear_LR_decay_to_zero=False,
clip_update=None,
+ re_init_every_n_epi=1,
+ Q_net_std=1.0,
**kwargs,
):
- print("args not used:", kwargs)
+ # print("args not used:", kwargs)
self.num_episodes = num_episodes
self.trace_length = trace_length
@@ -192,6 +292,8 @@ def _init_lola(
self.gamma = gamma
self.with_linear_LR_decay_to_zero = with_linear_LR_decay_to_zero
self.clip_update = clip_update
+ self.re_init_every_n_epi = re_init_every_n_epi
+ self.Q_net_std = Q_net_std
graph = tf.Graph()
@@ -202,11 +304,19 @@ def _init_lola(
if env_name == "IPD":
self.payout_mat_1 = np.array([[-1.0, 0.0], [-3.0, -2.0]])
self.payout_mat_2 = self.payout_mat_1.T
- elif env_name == "AsymmetricIteratedBoS":
+ elif env_name == "IteratedAsymBoS":
self.payout_mat_1 = np.array([[+4.0, 0.0], [0.0, +2.0]])
self.payout_mat_2 = np.array([[+1.0, 0.0], [0.0, +2.0]])
+ elif "custom_payoff_matrix" in kwargs.keys():
+ custom_matrix = kwargs["custom_payoff_matrix"]
+ self.payout_mat_1 = custom_matrix[:, :, 0]
+ self.payout_mat_2 = custom_matrix[:, :, 1]
else:
raise ValueError(f"exp_name: {env_name}")
+ self.n_actions = int(self.payout_mat_1.shape[0])
+ assert self.n_actions == self.payout_mat_1.shape[1]
+ self.policy1 = [0.0] * self.n_actions
+ self.policy2 = [0.0] * self.n_actions
# Sanity
@@ -219,6 +329,8 @@ def _init_lola(
self.num_hidden,
self.sess,
self.simple_net,
+ self.n_actions,
+ self.Q_net_std,
)
)
@@ -229,6 +341,7 @@ def _init_lola(
self.gamma,
self.pseudo,
self.reg,
+ self.n_actions,
)
self.results = []
@@ -241,21 +354,22 @@ def _init_lola(
    # TODO: avoid loading and creating everything when only evaluating with RLLib
def setup(self, config):
- print("_init_lola", config)
self._init_lola(**config)
def step(self):
- self.sess.run(self.init)
- lr_coor = np.ones(1) * self.lr_correction
-
log_items = {}
log_items["episode"] = self.training_iteration
+ if self.training_iteration % self.re_init_every_n_epi == 0:
+ self.sess.run(self.init)
+
+ lr_coor = np.ones(1) * self.lr_correction
+
res = []
params_time = []
delta_time = []
- input_vals = np.reshape(np.array(range(5)) + 1, [-1])
+ input_vals = self._get_input_vals()
for i in range(self.trace_length):
params0 = self.mainQN[0].getparams()
params1 = self.mainQN[1].getparams()
@@ -267,14 +381,13 @@ def step(self):
self.mainQN[0].policy,
self.mainQN[1].policy,
]
- # print("input_vals", input_vals)
- update1, update2, v1, v2, policy1, policy2 = self.sess.run(
+ (update1, update2, v1, v2, policy1, policy2) = self.sess.run(
outputs,
feed_dict={
self.mainQN[0].input_place: input_vals,
self.mainQN[1].input_place: input_vals,
- self.mainQN[0].R1: np.reshape(self.payout_mat_2, [-1, 1]),
- self.mainQN[1].R1: np.reshape(self.payout_mat_1, [-1, 1]),
+ self.mainQN[0].R1: np.reshape(self.payout_mat_1, [-1, 1]),
+ self.mainQN[1].R1: np.reshape(self.payout_mat_2, [-1, 1]),
self.mainQN[0].lr_correction: lr_coor,
self.mainQN[1].lr_correction: lr_coor,
},
@@ -289,10 +402,21 @@ def step(self):
res.append([v1[0][0] / self.norm, v2[0][0] / self.norm])
self.results.append(res)
+ if self.training_iteration % self.re_init_every_n_epi == (
+ self.re_init_every_n_epi - 1
+ ):
+ self.policy1 = policy1
+ self.policy2 = policy2
+
log_items["episodes_total"] = self.training_iteration
+ log_items["policy1"] = self.policy1
+ log_items["policy2"] = self.policy2
return log_items
+ def _get_input_vals(self):
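+        # States are indexed 1 .. n_actions ** 2 + 1: one per joint action
+        # of the previous step, plus one for the initial state.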
+ return np.reshape(np.array(range(self.n_actions ** 2 + 1)) + 1, [-1])
+
def _clip_update(self, update1, update2):
if self.clip_update is not None:
assert self.clip_update > 0.0
@@ -382,24 +506,28 @@ def _post_process_action(self, action):
return action[None, ...] # add batch dim
def compute_actions(self, policy_id: str, obs_batch: list):
- # because of the LSTM
- assert len(obs_batch) == 1
+ assert (
+ len(obs_batch) == 1
+ ), f"{len(obs_batch)} == 1. obs_batch: {obs_batch}"
for single_obs in obs_batch:
agent_to_use = self._get_agent_to_use(policy_id)
obs = self._preprocess_obs(single_obs, agent_to_use)
- input_vals = np.reshape(np.array(range(5)) + 1, [-1])
+ input_vals = self._get_input_vals()
policy = self.sess.run(
- [self.mainQN[agent_to_use].policy],
+ [
+ self.mainQN[agent_to_use].policy,
+ ],
feed_dict={
self.mainQN[agent_to_use].input_place: input_vals,
},
)
- coop_proba = policy[0][obs][0]
- if coop_proba > random.random():
- action = np.array(0)
- else:
- action = np.array(1)
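+            # Sample from the full categorical distribution over actions
+            # instead of thresholding a single cooperation probability.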
+ probabilities = policy[0][obs]
+ probabilities = torch.tensor(probabilities)
+ policy_for_this_state = torch.distributions.Categorical(
+ probs=probabilities
+ )
+ action = policy_for_this_state.sample()
action = self._post_process_action(action)
diff --git a/marltoolbox/algos/ltft/ltft.py b/marltoolbox/algos/ltft/ltft.py
index 5a91e15..203f7ad 100644
--- a/marltoolbox/algos/ltft/ltft.py
+++ b/marltoolbox/algos/ltft/ltft.py
@@ -12,13 +12,16 @@
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
-from ray.rllib.execution.train_ops import TrainOneStep
-from ray.rllib.execution.train_ops import UpdateTargetNetwork
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.typing import PolicyID, SampleBatchType
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
+from ray.rllib.execution.train_ops import (
+ TrainOneStep,
+ UpdateTargetNetwork,
+ TrainTFMultiGPU,
+)
from torch.nn import CrossEntropyLoss
from marltoolbox.algos import augmented_dqn, supervised_learning, hierarchical
@@ -116,7 +119,7 @@
framework="torch",
),
},
- "log_level": "DEBUG",
+ "batch_mode": "complete_episodes",
}
)
@@ -133,7 +136,7 @@
PLOT_ASSEMBLAGE_TAGS = [
("defection",),
- ("defection", "chosen_percentile"),
+ ("defection_", "chosen_percentile"),
("punish",),
("delta_log_likelihood_coop_std", "delta_log_likelihood_coop_mean"),
("likelihood_opponent", "likelihood_approximated_opponent"),
@@ -158,7 +161,15 @@ def execution_plan(
workers: WorkerSet, config: TrainerConfigDict
) -> LocalIterator[dict]:
"""
- Modified from the execution plan of the DQNTrainer
+ Execution plan of the DQN algorithm. Defines the distributed dataflow.
+
+ Args:
+ workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
+ of the Trainer.
+ config (TrainerConfigDict): The trainer's configuration dict.
+
+ Returns:
+ LocalIterator[dict]: A local iterator over training metrics.
"""
if config.get("prioritized_replay"):
prio_args = {
@@ -175,7 +186,9 @@ def execution_plan(
buffer_size=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
- replay_sequence_length=config["replay_sequence_length"],
+ replay_sequence_length=config.get("replay_sequence_length", 1),
+ replay_burn_in=config.get("burn_in", 0),
+ replay_zero_init_states=config.get("zero_init_states", True),
**prio_args,
)
@@ -204,10 +217,9 @@ def update_prio(item):
td_error = info.get(
"td_error", info[LEARNER_STATS_KEY].get("td_error")
)
+ samples.policy_batches[policy_id].set_get_interceptor(None)
prio_dict[policy_id] = (
- samples.policy_batches[policy_id].data.get(
- "batch_indexes"
- ),
+ samples.policy_batches[policy_id].get("batch_indexes"),
td_error,
)
local_replay_buffer.update_priorities(prio_dict)
@@ -217,11 +229,25 @@ def update_prio(item):
# returned from the LocalReplay() iterator is passed to TrainOneStep to
# take a SGD step, and then we decide whether to update the target network.
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
+
+ if config["simple_optimizer"]:
+ train_step_op = TrainOneStep(workers)
+ else:
+ train_step_op = TrainTFMultiGPU(
+ workers=workers,
+ sgd_minibatch_size=config["train_batch_size"],
+ num_sgd_iter=1,
+ num_gpus=config["num_gpus"],
+ shuffle_sequences=True,
+ _fake_gpus=config["_fake_gpus"],
+ framework=config.get("framework"),
+ )
+
replay_op_dqn = (
Replay(local_buffer=local_replay_buffer)
.for_each(lambda x: post_fn(x, workers, config))
.for_each(LocalTrainablePolicyModifier(workers, train_dqn_only))
- .for_each(TrainOneStep(workers))
+ .for_each(train_step_op)
.for_each(update_prio)
.for_each(
UpdateTargetNetwork(workers, config["target_network_update_freq"])
diff --git a/marltoolbox/algos/ltft/ltft_torch_policy.py b/marltoolbox/algos/ltft/ltft_torch_policy.py
index 05ce962..6101e5d 100644
--- a/marltoolbox/algos/ltft/ltft_torch_policy.py
+++ b/marltoolbox/algos/ltft/ltft_torch_policy.py
@@ -8,7 +8,7 @@
import copy
import logging
from collections import deque
-from typing import List, Union, Optional, Dict, Tuple, TYPE_CHECKING
+from typing import Optional, Dict, Tuple, TYPE_CHECKING, Union, List
import numpy as np
import torch
@@ -16,9 +16,11 @@
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AgentID, PolicyID, TensorType
+from ray.rllib.utils.torch_ops import convert_to_torch_tensor
if TYPE_CHECKING:
from ray.rllib.evaluation import RolloutWorker
@@ -45,6 +47,9 @@
},
)
+BEING_PUNISHED = "being_punished"
+PUNISHING = "punishing"
+
class LTFTTorchPolicy(hierarchical.HierarchicalTorchPolicy):
"""
@@ -63,12 +68,10 @@ class LTFTTorchPolicy(hierarchical.HierarchicalTorchPolicy):
INITIALLY_ACTIVE_ALGO = COOP_POLICY_IDX
def __init__(self, observation_space, action_space, config, **kwargs):
-
super().__init__(
observation_space,
action_space,
config,
- # after_init_nested=_init_weights_fn,
**kwargs,
)
@@ -100,10 +103,8 @@ def __init__(self, observation_space, action_space, config, **kwargs):
self.n_steps_since_start = 0
self.last_computed_w = None
- # self._episode_starting = True
self.opp_previous_obs = None
self.opp_new_obs = None
- self._first_fake_step_played = False
self.observed_n_step_in_current_epi = 0
self.data_queue = deque(maxlen=self.length_of_history)
@@ -140,47 +141,39 @@ def __init__(self, observation_space, action_space, config, **kwargs):
add_opponent_neg_reward=True,
)
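+        # Declare PUNISHING as a view requirement so that the extra action
+        # fetch emitted in _compute_action_helper is kept in the collected
+        # sample batches (assumption based on RLlib's trajectory view API).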
+ self.view_requirements.update(
+ {
+ PUNISHING: ViewRequirement(),
+ }
+ )
+
def _reset_learner_stats(self):
self.learner_stats = {"learner_stats": {}}
@override(hierarchical.HierarchicalTorchPolicy)
- def compute_actions(
- self,
- obs_batch: Union[List[TensorType], TensorType],
- state_batches: Optional[List[TensorType]] = None,
- prev_action_batch: Union[List[TensorType], TensorType] = None,
- prev_reward_batch: Union[List[TensorType], TensorType] = None,
- info_batch: Optional[Dict[str, list]] = None,
- episodes: Optional[List["MultiAgentEpisode"]] = None,
- explore: Optional[bool] = None,
- timestep: Optional[int] = None,
- **kwargs,
- ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
- assert (
- len(obs_batch) == 1
- ), "LE only works with sampling one step at a time"
-
- actions, state_out, extra_fetches = super().compute_actions(
- obs_batch,
- state_batches,
- prev_action_batch,
- prev_reward_batch,
- info_batch,
- episodes,
- explore,
- timestep,
- **kwargs,
+ def _compute_action_helper(
+ self, input_dict, state_batches, seq_lens, explore, timestep
+ ):
+ obs_batch = input_dict["obs"]
+ assert len(obs_batch) == 1, (
+ "The LTFT policy only works when sampling/infering one step "
+ "at a time."
)
- extra_fetches["punishing"] = [self._is_punishing() for _ in obs_batch]
- return actions, [], extra_fetches
+ actions, state_out, extra_fetches = super()._compute_action_helper(
+ input_dict, state_batches, seq_lens, explore, timestep
+ )
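+        # Broadcast whether this policy is currently punishing so that the
+        # opponent can read it from this agent's last pi_info.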
+ extra_fetches[PUNISHING] = [
+ self._is_punishing() for _ in obs_batch.tolist()
+ ]
+
+ return actions, state_out, extra_fetches
def _is_punishing(self):
return self.active_algo_idx == self.PUNITIVE_POLICY_IDX
@override(hierarchical.HierarchicalTorchPolicy)
def _learn_on_batch(self, samples: SampleBatch):
-
(
policies_idx_to_train,
policies_to_train,
@@ -237,14 +230,14 @@ def _modify_batch_for_policy(self, policy_n, samples_copy):
)
if policy_n == self.COOP_POLICY_IDX:
samples_copy = filter_sample_batch(
- samples_copy, remove=True, filter_key="punishing"
+ samples_copy, remove=True, filter_key=PUNISHING
)
elif policy_n == self.COOP_OPP_POLICY_IDX:
samples_copy[samples_copy.ACTIONS] = np.array(
samples_copy.data[postprocessing.OPPONENT_ACTIONS]
)
samples_copy = filter_sample_batch(
- samples_copy, remove=True, filter_key="being_punished"
+ samples_copy, remove=True, filter_key=BEING_PUNISHED
)
else:
raise ValueError()
@@ -260,7 +253,7 @@ def _modify_batch_for_policy(self, policy_n, samples_copy):
# this policy. This is because the opponent is not using its
# cooperative policy at that time.
samples_copy = filter_sample_batch(
- samples_copy, remove=True, filter_key="being_punished"
+ samples_copy, remove=True, filter_key=BEING_PUNISHED
)
else:
raise ValueError()
@@ -268,30 +261,25 @@ def _modify_batch_for_policy(self, policy_n, samples_copy):
return samples_copy
def on_episode_step(self, episode, policy_id, policy_ids, *args, **kwargs):
- if self._first_fake_step_played:
- (
- opp_previous_obs,
- opp_a,
- being_punished_by_opp,
- ) = self._get_information_from_opponent(
- episode, policy_id, policy_ids
+ (
+ opp_previous_obs,
+ opp_a,
+ being_punished_by_opp,
+ ) = self._get_information_from_opponent(episode, policy_id, policy_ids)
+
+ self.being_punished_by_opp = being_punished_by_opp
+ if not self.being_punished_by_opp:
+ self.n_steps_since_start += 1
+ self._put_log_likelihood_in_data_buffer(
+ opp_previous_obs, opp_a, self.data_queue
)
- self.being_punished_by_opp = being_punished_by_opp
- if not self.being_punished_by_opp:
- self.n_steps_since_start += 1
- self._put_log_likelihood_in_data_buffer(
- opp_previous_obs, opp_a, self.data_queue
- )
-
- if self.remaining_punishing_time > 0:
- self.n_punishment_steps_in_current_epi += 1
- else:
- self.n_cooperation_steps_in_current_epi += 1
-
- self.observed_n_step_in_current_epi += 1
+ if self.remaining_punishing_time > 0:
+ self.n_punishment_steps_in_current_epi += 1
else:
- self._first_fake_step_played = True
+ self.n_cooperation_steps_in_current_epi += 1
+
+ self.observed_n_step_in_current_epi += 1
def _get_information_from_opponent(self, episode, agent_id, agent_ids):
opp_agent_id = [one_id for one_id in agent_ids if one_id != agent_id][
@@ -299,7 +287,7 @@ def _get_information_from_opponent(self, episode, agent_id, agent_ids):
]
opp_a = episode.last_action_for(opp_agent_id)
being_punished_by_opp = episode.last_pi_info_for(opp_agent_id).get(
- "punishing", False
+ PUNISHING, False
)
return self.opp_previous_obs, opp_a, being_punished_by_opp
@@ -308,8 +296,7 @@ def on_observation_fn(self, opp_new_obs):
        # observation produced by this action. But we need the
        # observation that caused the agent to play this action,
        # i.e. the observation at step n-1.
- if self._first_fake_step_played:
- self.opp_previous_obs = self.opp_new_obs
+ self.opp_previous_obs = self.opp_new_obs
self.opp_new_obs = opp_new_obs
def on_episode_end(self, base_env, *args, **kwargs):
@@ -386,15 +373,11 @@ def postprocess_trajectory(
def _put_log_likelihood_in_data_buffer(self, s, a, data_queue, log=True):
s = torch.from_numpy(s).unsqueeze(dim=0)
- log_likelihood_opponent_cooperating = (
- compute_log_likelihoods_wt_exploration(
- self.algorithms[self.COOP_OPP_POLICY_IDX], a, s
- )
+ log_likelihood_opponent_cooperating = compute_log_likelihoods(
+ self.algorithms[self.COOP_OPP_POLICY_IDX], [a], s
)
- log_likelihood_approximated_opponent = (
- compute_log_likelihoods_wt_exploration(
- self.algorithms[self.SPL_OPP_POLICY_IDX], a, s
- )
+ log_likelihood_approximated_opponent = compute_log_likelihoods(
+ self.algorithms[self.SPL_OPP_POLICY_IDX], [a], s
)
log_likelihood_opponent_cooperating = float(
log_likelihood_opponent_cooperating
@@ -500,20 +483,20 @@ def on_postprocess_trajectory(
opp_policy_id = all_agent_keys[0]
opp_is_broadcast_punishment_state = (
- "punishing" in original_batches[opp_policy_id][1].data.keys()
+ PUNISHING in original_batches[opp_policy_id][1].keys()
)
if opp_is_broadcast_punishment_state:
- postprocessed_batch.data["being_punished"] = copy.deepcopy(
- original_batches[all_agent_keys[0]][1].data["punishing"]
+ postprocessed_batch.data[BEING_PUNISHED] = copy.deepcopy(
+ original_batches[all_agent_keys[0]][1][PUNISHING]
)
else:
- postprocessed_batch.data["being_punished"] = [False] * len(
+ postprocessed_batch.data[BEING_PUNISHED] = [False] * len(
postprocessed_batch[postprocessed_batch.OBS]
)
-# Modified from torch_policy_template
-def compute_log_likelihoods_wt_exploration(
+# Modified from torch_policy
+def compute_log_likelihoods(
policy,
actions: Union[List[TensorType], TensorType],
obs_batch: Union[List[TensorType], TensorType],
@@ -521,6 +504,7 @@ def compute_log_likelihoods_wt_exploration(
prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
) -> TensorType:
+
if policy.action_sampler_fn and policy.action_distribution_fn is None:
raise ValueError(
"Cannot compute log-prob/likelihood w/o an "
@@ -537,34 +521,63 @@ def compute_log_likelihoods_wt_exploration(
if prev_reward_batch is not None:
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
+ state_batches = [
+ convert_to_torch_tensor(s, policy.device)
+ for s in (state_batches or [])
+ ]
# Exploration hook before each forward pass.
policy.exploration.before_compute_actions(explore=False)
# Action dist class and inputs are generated via custom function.
if policy.action_distribution_fn:
- dist_inputs, dist_class, _ = policy.action_distribution_fn(
- policy=policy,
- model=policy.model,
- obs_batch=input_dict[SampleBatch.CUR_OBS],
- explore=False,
- is_training=False,
- )
+
+ # Try new action_distribution_fn signature, supporting
+ # state_batches and seq_lens.
+ try:
+ (dist_inputs, dist_class, _,) = policy.action_distribution_fn(
+ policy,
+ policy.model,
+ input_dict=input_dict,
+ state_batches=state_batches,
+ seq_lens=seq_lens,
+ explore=False,
+ is_training=False,
+ )
+ # Trying the old way (to stay backward compatible).
+ # TODO: Remove in future.
+ except TypeError as e:
+ if (
+ "positional argument" in e.args[0]
+ or "unexpected keyword argument" in e.args[0]
+ ):
+ dist_inputs, dist_class, _ = policy.action_distribution_fn(
+ policy=policy,
+ model=policy.model,
+ obs_batch=input_dict[SampleBatch.CUR_OBS],
+ explore=False,
+ is_training=False,
+ )
+ else:
+ raise e
+
# Default action-dist inputs calculation.
else:
dist_class = policy.dist_class
dist_inputs, _ = policy.model(input_dict, state_batches, seq_lens)
action_dist = dist_class(dist_inputs, policy.model)
+ # ADDITION 1/1 STARTS
if policy.config["explore"]:
        # Added because of a bug in TorchCategorical,
        # which modifies dist_inputs through action_dist:
- _, _ = policy.exploration.get_exploration_action(
+ _ = policy.exploration.get_exploration_action(
action_distribution=action_dist,
timestep=policy.global_timestep,
explore=policy.config["explore"],
)
action_dist = dist_class(dist_inputs, policy.model)
+ # ADDITION 1/1 ENDS
log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])
diff --git a/marltoolbox/algos/population.py b/marltoolbox/algos/population.py
index 81cb229..ea5e0b2 100644
--- a/marltoolbox/algos/population.py
+++ b/marltoolbox/algos/population.py
@@ -1,12 +1,9 @@
import copy
import random
-from typing import List, Union, Optional, Dict, Tuple
-from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import merge_dicts
-from ray.rllib.utils.annotations import override
-from ray.rllib.utils.typing import TensorType
+from ray.rllib.utils import override
from marltoolbox.algos import hierarchical
from marltoolbox.utils import miscellaneous, restore
@@ -17,7 +14,7 @@
# To configure
"policy_checkpoints": [],
"policy_id_to_load": None,
- "nested_policies": None,
+ # "nested_policies": None,
"freeze_algo": True,
"use_random_algo": True,
"use_algo_in_order": False,
@@ -54,12 +51,10 @@ def set_algo_to_use(self):
"""
Called by a callback at the start of every episode.
"""
-
if self.switch_of_algo_counter == self.switch_of_algo_every_n_epi:
self.switch_of_algo_counter = 0
self._set_algo_to_use()
- else:
- self.switch_of_algo_counter += 1
+ self.switch_of_algo_counter += 1
def _set_algo_to_use(self):
self._select_algo_idx_to_use()
@@ -75,10 +70,17 @@ def _select_algo_idx_to_use(self):
else:
raise ValueError()
+ self._to_log[
+ "PopulationOfIdenticalAlgo_active_checkpoint_idx"
+ ] = self.active_checkpoint_idx
+
def _use_random_algo(self):
self.active_checkpoint_idx = random.randint(
0, len(self.policy_checkpoints) - 1
)
+ self._to_log[
+ "PopulationOfIdenticalAlgo_active_checkpoint_idx"
+ ] = self.active_checkpoint_idx
def _use_next_algo(self):
self.active_checkpoint_idx += 1
@@ -110,22 +112,7 @@ def set_weights(self, weights):
self.algorithms[self.active_algo_idx].set_weights(weights)
@override(hierarchical.HierarchicalTorchPolicy)
- def compute_actions(
- self,
- obs_batch: Union[List[TensorType], TensorType],
- state_batches: Optional[List[TensorType]] = None,
- prev_action_batch: Union[List[TensorType], TensorType] = None,
- prev_reward_batch: Union[List[TensorType], TensorType] = None,
- info_batch: Optional[Dict[str, list]] = None,
- episodes: Optional[List["MultiAgentEpisode"]] = None,
- explore: Optional[bool] = None,
- timestep: Optional[int] = None,
- **kwargs,
- ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
- return self.algorithms[self.active_algo_idx].compute_actions(obs_batch)
-
- @override(hierarchical.HierarchicalTorchPolicy)
- def learn_on_batch(self, samples: SampleBatch):
+ def _learn_on_batch(self, samples: SampleBatch):
if not self.freeze_algo:
# TODO maybe need to call optimizer to update the LR of the nested optimizers
learner_stats = {"learner_stats": {}}
@@ -165,6 +152,11 @@ def on_episode_start(
**kwargs,
):
self.set_algo_to_use()
+ if hasattr(self.algorithms[self.active_algo_idx], "on_episode_start"):
+ self.algorithms[self.active_algo_idx].on_episode_start(
+ *args,
+ **kwargs,
+ )
def modify_config_to_use_population(
diff --git a/marltoolbox/algos/sos.py b/marltoolbox/algos/sos.py
new file mode 100644
index 0000000..dbf8079
--- /dev/null
+++ b/marltoolbox/algos/sos.py
@@ -0,0 +1,521 @@
+######
+# Code modified from:
+# https://github.com/julianstastny/openspiel-social-dilemmas
+######
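+#
+# Implements exact-gradient learners for two-player iterated matrix games
+# (naive, LookAhead "la", LOLA-exact "lola", SOS, SGA, CO, EG, CGD and LSS),
+# wrapped as a ray.tune.Trainable so they can be launched through tune.run.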
+import os
+import random
+
+import torch
+import numpy as np
+from ray import tune
+
+from marltoolbox.envs.matrix_sequential_social_dilemma import (
+ IteratedPrisonersDilemma,
+ IteratedAsymBoS,
+)
+
+
+class SOSTrainer(tune.Trainable):
+ def setup(self, config: dict):
+
+ self.config = config
+ self.gamma = self.config.get("gamma")
+ self.learning_rate = self.config.get("lr")
+ self.method = self.config.get("method")
+ self.use_single_weights = False
+
+ self._set_environment()
+ self.init_weights(std=self.config.get("inital_weights_std"))
+ if self.use_single_weights:
+ self._exact_loss_matrix_game = (
+ self._exact_loss_matrix_game_two_by_two_actions
+ )
+ else:
+ self._exact_loss_matrix_game = self._exact_loss_matrix_game_generic
+
+ def _set_environment(self):
+
+ if self.config.get("env_name") == "IteratedPrisonersDilemma":
+ env_class = IteratedPrisonersDilemma
+ payoff_matrix = env_class.PAYOFF_MATRIX
+ self.players_ids = env_class({}).players_ids
+ elif self.config.get("env_name") == "IteratedAsymBoS":
+ env_class = IteratedAsymBoS
+ payoff_matrix = env_class.PAYOFF_MATRIX
+ self.players_ids = env_class({}).players_ids
+ elif "custom_payoff_matrix" in self.config.keys():
+ payoff_matrix = self.config["custom_payoff_matrix"]
+ self.players_ids = IteratedPrisonersDilemma({}).players_ids
+ else:
+ raise NotImplementedError()
+
+ self.n_actions_p1 = np.array(payoff_matrix).shape[0]
+ self.n_actions_p2 = np.array(payoff_matrix).shape[1]
+ self.n_non_init_states = self.n_actions_p1 * self.n_actions_p2
+
+ if self.use_single_weights:
+ self.dims = [
+ (self.n_non_init_states + 1,),
+ (self.n_non_init_states + 1,),
+ ]
+ else:
+ self.dims = [
+ (self.n_non_init_states + 1, self.n_actions_p1),
+ (self.n_non_init_states + 1, self.n_actions_p2),
+ ]
+
+ self.payoff_matrix_player_row = torch.tensor(
+ payoff_matrix[:, :, 0]
+ ).float()
+ self.payoff_matrix_player_col = torch.tensor(
+ payoff_matrix[:, :, 1]
+ ).float()
+
+ def step(self):
+ losses = self.update_th()
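+ # Each loss is a negative discounted return, so negating it and scaling
+ # by (1 - gamma) gives an average per-step reward for logging.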
+ mean_reward_player_row = -losses[0] * (1 - self.gamma)
+ mean_reward_player_col = -losses[1] * (1 - self.gamma)
+ to_log = {
+ f"mean_reward_{self.players_ids[0]}": mean_reward_player_row,
+ f"mean_reward_{self.players_ids[1]}": mean_reward_player_col,
+ "episodes_total": self.training_iteration,
+ }
+ return to_log
+
+ def _exact_loss_matrix_game_two_by_two_actions(self):
+
+ pi_player_row_init_state = torch.sigmoid(
+ self.weights_per_players[0][0:1]
+ )
+ pi_player_col_init_state = torch.sigmoid(
+ self.weights_per_players[1][0:1]
+ )
+ p = torch.cat(
+ [
+ pi_player_row_init_state * pi_player_col_init_state,
+ pi_player_row_init_state * (1 - pi_player_col_init_state),
+ (1 - pi_player_row_init_state) * pi_player_col_init_state,
+ (1 - pi_player_row_init_state)
+ * (1 - pi_player_col_init_state),
+ ]
+ )
+ pi_player_row_other_states = torch.reshape(
+ torch.sigmoid(self.weights_per_players[0][1:5]), (4, 1)
+ )
+ pi_player_col_other_states = torch.reshape(
+ torch.sigmoid(self.weights_per_players[1][1:5]), (4, 1)
+ )
+ P = torch.cat(
+ [
+ pi_player_row_other_states * pi_player_col_other_states,
+ pi_player_row_other_states * (1 - pi_player_col_other_states),
+ (1 - pi_player_row_other_states) * pi_player_col_other_states,
+ (1 - pi_player_row_other_states)
+ * (1 - pi_player_col_other_states),
+ ],
+ dim=1,
+ )
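+ # M = -p @ (I - gamma * P)^-1 is the (negative) discounted visitation
+ # over the four joint-action states; multiplying it by a payoff column
+ # yields the negative expected discounted return, i.e. the exact loss.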
+ M = -torch.matmul(p, torch.inverse(torch.eye(4) - self.gamma * P))
+ L_1 = torch.matmul(
+ M, torch.reshape(self.payoff_matrix_player_row, (4, 1))
+ )
+ L_2 = torch.matmul(
+ M, torch.reshape(self.payoff_matrix_player_col, (4, 1))
+ )
+ return [L_1, L_2]
+
+ def _exact_loss_matrix_game_generic(self):
+ pi_player_row = torch.sigmoid(self.weights_per_players[0])
+ pi_player_col = torch.sigmoid(self.weights_per_players[1])
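+ # Normalize the sigmoid outputs per state so that each row forms a
+ # valid probability distribution over that player's actions.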
+ sum_1 = torch.sum(pi_player_row, dim=1)
+ sum_1 = torch.stack([sum_1 for _ in range(self.n_actions_p1)], dim=1)
+ sum_2 = torch.sum(pi_player_col, dim=1)
+ sum_2 = torch.stack([sum_2 for _ in range(self.n_actions_p2)], dim=1)
+ pi_player_row = pi_player_row / sum_1
+ pi_player_col = pi_player_col / sum_2
+
+ pi_player_row_init_state = pi_player_row[:1, :]
+ pi_player_col_init_state = pi_player_col[:1, :]
+ all_initial_actions_proba_pairs = []
+ for action_p1 in range(self.n_actions_p1):
+ for action_p2 in range(self.n_actions_p2):
+ all_initial_actions_proba_pairs.append(
+ pi_player_row_init_state[:, action_p1]
+ * pi_player_col_init_state[:, action_p2]
+ )
+ p = torch.cat(
+ all_initial_actions_proba_pairs,
+ )
+
+ pi_player_row_other_states = pi_player_row[1:, :]
+ pi_player_col_other_states = pi_player_col[1:, :]
+ all_actions_proba_pairs = []
+ for action_p1 in range(self.n_actions_p1):
+ for action_p2 in range(self.n_actions_p2):
+ all_actions_proba_pairs.append(
+ pi_player_row_other_states[:, action_p1]
+ * pi_player_col_other_states[:, action_p2]
+ )
+ P = torch.stack(
+ all_actions_proba_pairs,
+ 1,
+ )
+
+ M = -torch.matmul(
+ p,
+ torch.inverse(torch.eye(self.n_non_init_states) - self.gamma * P),
+ )
+ L_1 = torch.matmul(
+ M,
+ torch.reshape(
+ self.payoff_matrix_player_row, (self.n_non_init_states, 1)
+ ),
+ )
+ L_2 = torch.matmul(
+ M,
+ torch.reshape(
+ self.payoff_matrix_player_col, (self.n_non_init_states, 1)
+ ),
+ )
+ return [L_1, L_2]
+
+ def init_weights(self, std):
+ self.weights_per_players = []
+ self.n_players = len(self.dims)
+ for i in range(self.n_players):
+ if std > 0:
+ init = torch.nn.init.normal_(
+ torch.empty(self.dims[i], requires_grad=True), std=std
+ )
+ else:
+ init = torch.zeros(self.dims[i], requires_grad=True)
+ self.weights_per_players.append(init)
+
+ def update_th(
+ self,
+ a=0.5,
+ b=0.1,
+ gam=1,
+ ep=0.1,
+ lss_lam=0.1,
+ ):
+ losses = self._exact_loss_matrix_game()
+
+ grads = self._compute_gradients_wt_selected_method(
+ a,
+ b,
+ ep,
+ gam,
+ losses,
+ lss_lam,
+ )
+
+ self._update_weights(grads)
+ return losses
+
+ def _compute_gradients_wt_selected_method(
+ self, a, b, ep, gam, losses, lss_lam
+ ):
+ n_players = self.n_players
+
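+ # grad_L[i][j] holds dL_j/dtheta_i: the gradient of player j's loss
+ # with respect to player i's parameters.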
+ grad_L = _compute_vanilla_gradients(
+ losses, n_players, self.weights_per_players
+ )
+
+ if self.method == "la":
+ grads = _get_la_gradients(
+ self.learning_rate, grad_L, n_players, self.weights_per_players
+ )
+ elif self.method == "lola":
+ grads = _get_lola_exact_gradients(
+ self.learning_rate, grad_L, n_players, self.weights_per_players
+ )
+ elif self.method == "sos":
+ grads = _get_sos_gradients(
+ a,
+ self.learning_rate,
+ b,
+ grad_L,
+ n_players,
+ self.weights_per_players,
+ )
+ elif self.method == "sga":
+ grads = _get_sga_gradients(
+ ep, grad_L, n_players, self.weights_per_players
+ )
+ elif self.method == "co":
+ grads = _get_co_gradients(
+ gam, grad_L, n_players, self.weights_per_players
+ )
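+ # Note: _get_eg_gradients calls the loss with candidate parameters, so
+ # this branch requires a loss function that accepts a parameter list
+ # (the bound zero-argument losses above do not).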
+ elif self.method == "eg":
+ grads = _get_eg_gradients(
+ self._exact_loss_matrix_game,
+ self.learning_rate,
+ losses,
+ n_players,
+ self.weights_per_players,
+ )
+ elif self.method == "cgd": # Slow implementation (matrix inversion)
+ grads = _get_cgd_gradients(
+ self.learning_rate, grad_L, n_players, self.weights_per_players
+ )
+ elif self.method == "lss": # Slow implementation (matrix inversion)
+ grads = _get_lss_gradients(
+ grad_L, lss_lam, n_players, self.weights_per_players
+ )
+ elif self.method == "naive": # Naive Learning
+ grads = _get_naive_learning_gradients(grad_L, n_players)
+ else:
+ raise ValueError(f"algo: {self.method}")
+ return grads
+
+ def _update_weights(self, grads):
+ with torch.no_grad():
+ for weight_i, weight_grad in enumerate(grads):
+ self.weights_per_players[weight_i] -= (
+ self.learning_rate * weight_grad
+ )
+
+ def save_checkpoint(self, checkpoint_dir):
+ save_path = os.path.join(checkpoint_dir, "weights.pt")
+ torch.save(self.weights_per_players, save_path)
+ return save_path
+
+ def load_checkpoint(self, checkpoint_path):
+ self.weights_per_players = torch.load(checkpoint_path)
+
+ def _get_agent_to_use(self, policy_id):
+ if policy_id == "player_row":
+ agent_n = 0
+ elif policy_id == "player_col":
+ agent_n = 1
+ else:
+ raise ValueError(f"policy_id {policy_id}")
+ return agent_n
+
+ def _preprocess_obs(self, single_obs, agent_to_use):
+ single_obs = np.where(single_obs == 1)[0][0]
+ # In the weight matrix, index 0 corresponds to the initial state,
+ # whereas in the environment observation the initial state is
+ # encoded by the last index.
+ if single_obs == len(self.weights_per_players[0]) - 1:
+ single_obs = 0
+ else:
+ single_obs += 1
+ return single_obs
+
+ def _post_process_action(self, action):
+ return action[None, ...] # add batch dim
+
+ def compute_actions(self, policy_id: str, obs_batch: list):
+ assert (
+ len(obs_batch) == 1
+ ), f"{len(obs_batch)} == 1. obs_batch: {obs_batch}"
+
+ for single_obs in obs_batch:
+ agent_to_use = self._get_agent_to_use(policy_id)
+ obs = self._preprocess_obs(single_obs, agent_to_use)
+ policy = torch.sigmoid(self.weights_per_players[agent_to_use])
+
+ if self.use_single_weights:
+ coop_proba = policy[obs]
+ if coop_proba > random.random():
+ action = np.array(0)
+ else:
+ action = np.array(1)
+ else:
+ probabilities = policy[obs, :]
+ # Detach instead of torch.tensor(tensor), which copies and warns;
+ # no gradients are needed when sampling actions.
+ probabilities = probabilities.detach()
+ policy_for_this_state = torch.distributions.Categorical(
+ probs=probabilities
+ )
+ action = np.array(policy_for_this_state.sample())
+
+ action = self._post_process_action(action)
+
+ state_out = []
+ extra_fetches = {}
+ return action, state_out, extra_fetches
+
+
+def _compute_vanilla_gradients(losses, n, th):
+ grad_L = [
+ [get_gradient(losses[j], th[i]) for j in range(n)] for i in range(n)
+ ]
+ return grad_L
+
+
+def _get_la_gradients(alpha, grad_L, n, th):
+ terms = [
+ sum(
+ [
+ # torch.dot(grad_L[j][i], grad_L[j][j].detach())
+ torch.matmul(
+ grad_L[j][i],
+ torch.transpose(grad_L[j][j].detach(), 0, 1),
+ )
+ for j in range(n)
+ if j != i
+ ]
+ )
+ for i in range(n)
+ ]
+ grads = [
+ grad_L[i][i] - alpha * get_gradient(terms[i], th[i]) for i in range(n)
+ ]
+ return grads
+
+
+def _get_lola_exact_gradients(alpha, grad_L, n, th):
+ terms = [
+ sum(
+ # [torch.dot(grad_L[j][i], grad_L[j][j]) for j in range(n) if j != i]
+ [
+ torch.matmul(
+ grad_L[j][i],
+ torch.transpose(grad_L[j][j].unsqueeze(dim=1), 0, 1),
+ )
+ for j in range(n)
+ if j != i
+ ]
+ )
+ for i in range(n)
+ ]
+ grads = [
+ grad_L[i][i] - alpha * get_gradient(terms[i], th[i]) for i in range(n)
+ ]
+ return grads
+
+
+def _get_sos_gradients(a, alpha, b, grad_L, n, th):
+ xi_0 = _get_la_gradients(alpha, grad_L, n, th)
+ chi = [
+ get_gradient(
+ sum(
+ [
+ # torch.dot(grad_L[j][i].detach(), grad_L[j][j])
+ torch.matmul(
+ grad_L[j][i].detach(),
+ torch.transpose(grad_L[j][j], 0, 1),
+ )
+ for j in range(n)
+ if j != i
+ ]
+ ),
+ th[i],
+ )
+ for i in range(n)
+ ]
+ # Compute p
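+ # p = min(p1, p2) follows the two SOS scaling criteria: p1 shrinks the
+ # -alpha * chi correction when it would oppose the LookAhead direction
+ # xi_0, and p2 shrinks it near stationary points (small joint gradient).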
+ # dot = torch.dot(-alpha * torch.cat(chi), torch.cat(xi_0))
+ dot = torch.matmul(
+ -alpha * torch.cat(chi),
+ torch.transpose(torch.cat(xi_0), 0, 1),
+ ).sum()
+ p1 = 1 if dot >= 0 else min(1, -a * torch.norm(torch.cat(xi_0)) ** 2 / dot)
+ xi = torch.cat([grad_L[i][i] for i in range(n)])
+ xi_norm = torch.norm(xi)
+ p2 = xi_norm ** 2 if xi_norm < b else 1
+ p = min(p1, p2)
+ grads = [xi_0[i] - p * alpha * chi[i] for i in range(n)]
+ return grads
+
+
+def _get_sga_gradients(ep, grad_L, n, th):
+ xi = torch.cat([grad_L[i][i] for i in range(n)])
+ ham = torch.dot(xi, xi.detach())
+ H_t_xi = [get_gradient(ham, th[i]) for i in range(n)]
+ H_xi = [
+ get_gradient(
+ sum(
+ [
+ torch.dot(grad_L[j][i], grad_L[j][j].detach())
+ for j in range(n)
+ ]
+ ),
+ th[i],
+ )
+ for i in range(n)
+ ]
+ A_t_xi = [H_t_xi[i] / 2 - H_xi[i] / 2 for i in range(n)]
+ # Compute lambda (sga with alignment)
+ dot_xi = torch.dot(xi, torch.cat(H_t_xi))
+ dot_A = torch.dot(torch.cat(A_t_xi), torch.cat(H_t_xi))
+ d = sum([len(th[i]) for i in range(n)])
+ lam = torch.sign(dot_xi * dot_A / d + ep)
+ grads = [grad_L[i][i] + lam * A_t_xi[i] for i in range(n)]
+ return grads
+
+
+def _get_co_gradients(gam, grad_L, n, th):
+ xi = torch.cat([grad_L[i][i] for i in range(n)])
+ ham = torch.dot(xi, xi.detach())
+ grads = [grad_L[i][i] + gam * get_gradient(ham, th[i]) for i in range(n)]
+ return grads
+
+
+def _get_eg_gradients(Ls, alpha, losses, n, th):
+ th_eg = [th[i] - alpha * get_gradient(losses[i], th[i]) for i in range(n)]
+ losses_eg = Ls(th_eg)
+ grads = [get_gradient(losses_eg[i], th_eg[i]) for i in range(n)]
+ return grads
+
+
+def _get_naive_learning_gradients(grad_L, n):
+ grads = [grad_L[i][i] for i in range(n)]
+ return grads
+
+
+def _get_lss_gradients(grad_L, lss_lam, n, th):
+ dims = [len(th[i]) for i in range(n)]
+ xi = torch.cat([grad_L[i][i] for i in range(n)])
+ H = get_hessian(th, grad_L)
+ if torch.det(H) == 0:
+ inv = torch.inverse(
+ torch.matmul(H.T, H) + lss_lam * torch.eye(sum(dims))
+ )
+ H_inv = torch.matmul(inv, H.T)
+ else:
+ H_inv = torch.inverse(H)
+ grad = (
+ torch.matmul(torch.eye(sum(dims)) + torch.matmul(H.T, H_inv), xi) / 2
+ )
+ grads = [grad[sum(dims[:i]) : sum(dims[: i + 1])] for i in range(n)]
+ return grads
+
+
+def _get_cgd_gradients(alpha, grad_L, n, th):
+ dims = [len(th[i]) for i in range(n)]
+ xi = torch.cat([grad_L[i][i] for i in range(n)])
+ H_o = get_hessian(th, grad_L, diag=False)
+ grad = torch.matmul(torch.inverse(torch.eye(sum(dims)) + alpha * H_o), xi)
+ grads = [grad[sum(dims[:i]) : sum(dims[: i + 1])] for i in range(n)]
+ return grads
+
+
+def get_gradient(function, param):
+ grad = torch.autograd.grad(torch.sum(function), param, create_graph=True)[
+ 0
+ ]
+ return grad
+
+
+def get_hessian(th, grad_L, diag=True, off_diag=True):
+ n = len(th)
+ H = []
+ for i in range(n):
+ row_block = []
+ for j in range(n):
+ if (i == j and diag) or (i != j and off_diag):
+ block = [
+ torch.unsqueeze(
+ get_gradient(grad_L[i][i][k], th[j]),
+ dim=0,
+ )
+ for k in range(len(th[i]))
+ ]
+ row_block.append(torch.cat(block, dim=0))
+ else:
+ row_block.append(torch.zeros(len(th[i]), len(th[j])))
+ H.append(torch.cat(row_block, dim=1))
+ return torch.cat(H, dim=0)
diff --git a/marltoolbox/algos/stochastic_population.py b/marltoolbox/algos/stochastic_population.py
new file mode 100644
index 0000000..6a76b12
--- /dev/null
+++ b/marltoolbox/algos/stochastic_population.py
@@ -0,0 +1,33 @@
+import torch
+from ray.rllib import SampleBatch
+from ray.rllib.utils import override
+
+from marltoolbox.algos.amTFT import base
+from marltoolbox.algos.hierarchical import HierarchicalTorchPolicy
+
+
+class StochasticPopulation(HierarchicalTorchPolicy, base.AmTFTReferenceClass):
+ def __init__(self, observation_space, action_space, config, **kwargs):
+ super().__init__(observation_space, action_space, config, **kwargs)
+ self.stochastic_population_policy = torch.distributions.Categorical(
+ probs=self.config["sampling_policy_distribution"]
+ )
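+ # "sampling_policy_distribution" is expected to hold one probability per
+ # nested algorithm; a new active policy is sampled from it at the start
+ # of every episode.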
+
+ def on_episode_start(self, *args, **kwargs):
+ self._select_algo_idx_to_use()
+ if hasattr(self.algorithms[self.active_algo_idx], "on_episode_start"):
+ self.algorithms[self.active_algo_idx].on_episode_start(
+ *args,
+ **kwargs,
+ )
+
+ def _select_algo_idx_to_use(self):
+ policy_idx_selected = self.stochastic_population_policy.sample()
+ self.active_algo_idx = policy_idx_selected
+ self._to_log[
+ "StochasticPopulation_active_algo_idx"
+ ] = self.active_algo_idx
+
+ @override(HierarchicalTorchPolicy)
+ def _learn_on_batch(self, samples: SampleBatch):
+ return self.algorithms[self.active_algo_idx]._learn_on_batch(samples)
diff --git a/marltoolbox/algos/welfare_coordination.py b/marltoolbox/algos/welfare_coordination.py
index 559bc83..13516a3 100644
--- a/marltoolbox/algos/welfare_coordination.py
+++ b/marltoolbox/algos/welfare_coordination.py
@@ -2,6 +2,7 @@
import logging
import random
+import torch
import numpy as np
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
@@ -27,6 +28,7 @@
"opp_player_idx": None,
"own_default_welfare_fn": None,
"opp_default_welfare_fn": None,
+ "distrib_over_welfare_sets_to_annonce": False,
},
)
@@ -60,20 +62,25 @@ def setup_meta_game(
self.opp_player_idx = opp_player_idx
self.own_default_welfare_fn = own_default_welfare_fn
self.opp_default_welfare_fn = opp_default_welfare_fn
- self._list_all_welfares_fn()
+ self.all_welfare_fn = MetaGameSolver.list_all_welfares_fn(
+ self.all_welfare_pairs_wt_payoffs
+ )
self._list_all_set_of_welfare_fn()
- def _list_all_welfares_fn(self):
+ @staticmethod
+ def list_all_welfares_fn(all_welfare_pairs_wt_payoffs):
+ """List all welfare functions in the given data structure"""
all_welfare_fn = []
- for welfare_pairs_key in self.all_welfare_pairs_wt_payoffs.keys():
- for welfare_fn in self._key_to_pair_of_welfare_names(
+ for welfare_pairs_key in all_welfare_pairs_wt_payoffs.keys():
+ for welfare_fn in MetaGameSolver.key_to_pair_of_welfare_names(
welfare_pairs_key
):
all_welfare_fn.append(welfare_fn)
- self.all_welfare_fn = tuple(set(all_welfare_fn))
+ return sorted(tuple(set(all_welfare_fn)))
@staticmethod
- def _key_to_pair_of_welfare_names(key):
+ def key_to_pair_of_welfare_names(key):
+ """Convert a key to the two welfare functions composing this key"""
return key.split("-")
def _list_all_set_of_welfare_fn(self):
@@ -90,7 +97,7 @@ def _list_all_set_of_welfare_fn(self):
frozenset(combi) for combi in combinations_object
]
welfare_fn_sets.extend(combinations_set)
- self.welfare_fn_sets = tuple(set(welfare_fn_sets))
+ self.welfare_fn_sets = sorted(tuple(set(welfare_fn_sets)))
def solve_meta_game(self, tau):
"""
@@ -98,7 +105,10 @@ def solve_meta_game(self, tau):
:param tau:
"""
print("================================================")
- print(f"Start solving meta game with tau={tau}")
+ print(
+ f"Player (idx={self.own_player_idx}) starts solving meta game "
+ f"with tau={tau}"
+ )
self.tau = tau
self.selected_pure_policy_idx = None
self.best_objective = -np.inf
@@ -116,6 +126,7 @@ def solve_meta_game(self, tau):
]
def _search_for_best_action(self):
+ self.tau_log = []
for idx, welfare_set_annonced in enumerate(self.welfare_fn_sets):
optimization_objective = self._compute_optimization_objective(
self.tau, welfare_set_annonced
@@ -123,26 +134,40 @@ def _search_for_best_action(self):
self._keep_action_if_best(optimization_objective, idx)
def _compute_optimization_objective(self, tau, welfare_set_annonced):
- print(f"========")
- print(f"compute objective for set {welfare_set_annonced}")
self._compute_payoffs_for_every_opponent_action(welfare_set_annonced)
opp_best_response_idx = self._get_opp_best_response_idx()
+
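+ # tau trades off exploitability (payoff when the opponent best-responds
+ # to the announced welfare set) against robustness (average payoff over
+ # all possible opponent announcements).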
+ exploitability_term = tau * self._get_own_payoff(opp_best_response_idx)
+ payoffs = self._get_all_possible_payoffs_excluding_provided(
+ # excluding_idx=opp_best_response_idx
+ )
+ robustness_term = (1 - tau) * sum(payoffs) / len(payoffs)
+
+ optimization_objective = exploitability_term + robustness_term
+
+ print(f"========")
+ print(f"compute objective for set {welfare_set_annonced}")
print(
f"opponent best response is {opp_best_response_idx} => "
f"{self.welfare_fn_sets[opp_best_response_idx]}"
)
- exploitability_term = tau * self._get_own_payoff(opp_best_response_idx)
print("exploitability_term", exploitability_term)
+ print("robustness_term", robustness_term)
+ print("optimization_objective", optimization_objective)
- payoffs = self._get_all_possible_payoffs_excluding_one(
- excluding_idx=opp_best_response_idx
+ self.tau_log.append(
+ {
+ "welfare_set_annonced": welfare_set_annonced,
+ "opp_best_response_idx": opp_best_response_idx,
+ "opp_best_response_set": self.welfare_fn_sets[
+ opp_best_response_idx
+ ],
+ "robustness_term": robustness_term,
+ "optimization_objective": optimization_objective,
+ }
)
- robustness_term = (1 - tau) * sum(payoffs) / len(payoffs)
- print("robustness_term", robustness_term)
- optimization_objective = exploitability_term + robustness_term
- print("optimization_objective", optimization_objective)
return optimization_objective
def _compute_payoffs_for_every_opponent_action(self, welfare_set_annonced):
@@ -158,18 +183,18 @@ def _compute_meta_payoff(self, own_welfare_set, opp_welfare_set):
welfare_fn_intersection = own_welfare_set & opp_welfare_set
if len(welfare_fn_intersection) == 0:
- return self._get_own_payoff_for_default_strategies()
+ return self._read_payoffs_for_default_strategies()
else:
- return self._get_own_payoff_averaging_over_welfare_intersection(
+ return self._read_payoffs_and_average_over_welfare_intersection(
welfare_fn_intersection
)
- def _get_own_payoff_for_default_strategies(self):
+ def _read_payoffs_for_default_strategies(self):
return self._read_own_payoff_from_data(
self.own_default_welfare_fn, self.opp_default_welfare_fn
)
- def _get_own_payoff_averaging_over_welfare_intersection(
+ def _read_payoffs_and_average_over_welfare_intersection(
self, welfare_fn_intersection
):
payoffs_player_1 = []
@@ -182,16 +207,17 @@ def _get_own_payoff_averaging_over_welfare_intersection(
payoffs_player_2.append(payoff_player_2)
mean_payoff_p1 = sum(payoffs_player_1) / len(payoffs_player_1)
mean_payoff_p2 = sum(payoffs_player_2) / len(payoffs_player_2)
- return (mean_payoff_p1, mean_payoff_p2)
+ return mean_payoff_p1, mean_payoff_p2
def _read_own_payoff_from_data(self, own_welfare, opp_welfare):
- welfare_pair_name = self._from_pair_of_welfare_names_to_key(
+ welfare_pair_name = self.from_pair_of_welfare_names_to_key(
own_welfare, opp_welfare
)
return self.all_welfare_pairs_wt_payoffs[welfare_pair_name]
@staticmethod
- def _from_pair_of_welfare_names_to_key(own_welfare_set, opp_welfare_set):
+ def from_pair_of_welfare_names_to_key(own_welfare_set, opp_welfare_set):
+ """Convert two welfare functions into a key identifying this pair"""
return f"{own_welfare_set}-{opp_welfare_set}"
def _get_opp_best_response_idx(self):
@@ -201,13 +227,15 @@ def _get_opp_best_response_idx(self):
]
return opp_payoffs.index(max(opp_payoffs))
- def _get_all_possible_payoffs_excluding_one(self, excluding_idx):
+ def _get_all_possible_payoffs_excluding_provided(self, excluding_idx=()):
own_payoffs = []
for welfare_set_idx, _ in enumerate(self.welfare_fn_sets):
- if welfare_set_idx != excluding_idx:
+ if welfare_set_idx not in excluding_idx:
own_payoff = self._get_own_payoff(welfare_set_idx)
own_payoffs.append(own_payoff)
- assert len(own_payoffs) == len(self.welfare_fn_sets) - 1
+ assert len(own_payoffs) == len(self.welfare_fn_sets) - len(
+ excluding_idx
+ )
return own_payoffs
def _get_own_payoff(self, idx):
@@ -253,14 +281,47 @@ def __init__(
**kwargs,
)
+ assert (
+ self._use_distribution_over_sets()
+ or self.config["solve_meta_game_after_init"]
+ )
if self.config["solve_meta_game_after_init"]:
+ assert not self.config["distrib_over_welfare_sets_to_annonce"]
self._choose_which_welfare_set_to_annonce()
+ if self._use_distribution_over_sets():
+ self._init_distribution_over_sets_to_announce()
self._welfare_set_annonced = False
self._welfare_set_in_use = None
self._welfare_chosen_for_epi = False
self._intersection_of_welfare_sets = None
+ def _init_distribution_over_sets_to_announce(self):
+ assert not self.config["solve_meta_game_after_init"]
+ self.setup_meta_game(
+ all_welfare_pairs_wt_payoffs=self.config[
+ "all_welfare_pairs_wt_payoffs"
+ ],
+ own_player_idx=self.config["own_player_idx"],
+ opp_player_idx=self.config["opp_player_idx"],
+ own_default_welfare_fn=self.config["own_default_welfare_fn"],
+ opp_default_welfare_fn=self.config["opp_default_welfare_fn"],
+ )
+ self.stochastic_announcement_policy = torch.distributions.Categorical(
+ probs=self.config["distrib_over_welfare_sets_to_annonce"]
+ )
+ self._stochastic_selection_of_welfare_set_to_announce()
+
+ def _stochastic_selection_of_welfare_set_to_announce(self):
+ welfare_set_idx_selected = self.stochastic_announcement_policy.sample()
+ self.welfare_set_to_annonce = self.welfare_fn_sets[
+ welfare_set_idx_selected
+ ]
+ self._to_log["welfare_set_to_annonce"] = str(
+ self.welfare_set_to_annonce
+ )
+ print("self.welfare_set_to_annonce", self.welfare_set_to_annonce)
+
def _choose_which_welfare_set_to_annonce(self):
self.setup_meta_game(
all_welfare_pairs_wt_payoffs=self.config[
@@ -272,6 +333,9 @@ def _choose_which_welfare_set_to_annonce(self):
opp_default_welfare_fn=self.config["opp_default_welfare_fn"],
)
self.solve_meta_game(self.config["tau"])
+ self._to_log[f"tau_{self.tau}"] = self.tau_log[
+ self.selected_pure_policy_idx
+ ]
@property
def policy_checkpoints(self):
@@ -282,9 +346,9 @@ def policy_checkpoints(self):
@policy_checkpoints.setter
def policy_checkpoints(self, value):
- msg = f"ignoring set self.policy_checkpoints to value {value}"
- print(msg)
- logger.warning(msg)
+ logger.warning(
+ f"ignoring setting self.policy_checkpoints to value {value}"
+ )
@override(population.PopulationOfIdenticalAlgo)
def set_algo_to_use(self):
@@ -310,7 +374,16 @@ def on_episode_start(
if not self._welfare_chosen_for_epi:
# Called only by one agent, not by both
self._coordinate_on_welfare_to_use_for_epi(worker)
- self.set_algo_to_use()
+ super().on_episode_start(
+ *args,
+ worker,
+ policy,
+ policy_id,
+ policy_ids,
+ **kwargs,
+ )
+ assert self._welfare_set_annonced
+ assert self._welfare_chosen_for_epi
@staticmethod
def _annonce_welfare_sets(
@@ -357,19 +430,35 @@ def _coordinate_on_welfare_to_use_for_epi(worker):
if isinstance(policy, WelfareCoordinationTorchPolicy):
if len(policy._intersection_of_welfare_sets) > 0:
if welfare_to_play is None:
- print(
- "policy._welfare_set_in_use",
- policy._welfare_set_in_use,
- )
welfare_to_play = random.choice(
tuple(policy._welfare_set_in_use)
)
policy._welfare_in_use = welfare_to_play
+ msg = (
+ "Policies coordinated "
+ "to play for the next epi: "
+ f"_welfare_set_in_use={policy._welfare_set_in_use}"
+ f"and selected: policy._welfare_in_use={policy._welfare_in_use}"
+ )
+ logger.info(msg)
+ print(msg)
else:
policy._welfare_in_use = random.choice(
tuple(policy._welfare_set_in_use)
)
+ msg = (
+ "Policies did NOT coordinate "
+ "to play for the next epi: "
+ f"_welfare_set_in_use={policy._welfare_set_in_use}"
+ f"and selected: policy._welfare_in_use={policy._welfare_in_use}"
+ )
+ logger.info(msg)
+ print(msg)
policy._welfare_chosen_for_epi = True
+ policy._to_log["welfare_in_use"] = policy._welfare_in_use
+ policy._to_log[
+ "welfare_set_in_use"
+ ] = policy._welfare_set_in_use
def on_episode_end(
self,
@@ -377,4 +466,36 @@ def on_episode_end(
**kwargs,
):
self._welfare_chosen_for_epi = False
- self.algorithms[self.active_algo_idx].on_episode_end(*args, **kwargs)
+ if hasattr(self.algorithms[self.active_algo_idx], "on_episode_end"):
+ self.algorithms[self.active_algo_idx].on_episode_end(
+ *args, **kwargs
+ )
+ if self._use_distribution_over_sets():
+ self._stochastic_selection_of_welfare_set_to_announce()
+ self._welfare_set_annonced = False
+
+ def _use_distribution_over_sets(self):
+ return self.config["distrib_over_welfare_sets_to_annonce"] is not False
+
+ @property
+ @override(population.PopulationOfIdenticalAlgo)
+ def to_log(self):
+ to_log = {
+ "meta_policy": self._to_log,
+ "nested_policy": {
+ f"policy_{algo_idx}": algo.to_log
+ for algo_idx, algo in enumerate(self.algorithms)
+ if hasattr(algo, "to_log")
+ },
+ }
+ return to_log
+
+ @to_log.setter
+ @override(population.PopulationOfIdenticalAlgo)
+ def to_log(self, value):
+ if value == {}:
+ for algo in self.algorithms:
+ if hasattr(algo, "to_log"):
+ algo.to_log = {}
+
+ self._to_log = value
diff --git a/marltoolbox/envs/coin_game.py b/marltoolbox/envs/coin_game.py
index 8c56f8f..4561c55 100644
--- a/marltoolbox/envs/coin_game.py
+++ b/marltoolbox/envs/coin_game.py
@@ -14,17 +14,22 @@
logger = logging.getLogger(__name__)
-PLOT_KEYS = ["pick_speed",
- "pick_own_color",
- ]
+PLOT_KEYS = [
+ "pick_speed",
+ "pick_own_color",
+]
PLOT_ASSEMBLAGE_TAGS = [
("pick_own_color_player_red_mean", "pick_own_color_player_blue_mean"),
("pick_speed_player_red_mean", "pick_speed_player_blue_mean"),
("pick_own_color",),
("pick_speed", "pick_own_color"),
- ("pick_own_color_player_red_mean", "pick_own_color_player_blue_mean",
- "pick_speed_player_red_mean", "pick_speed_player_blue_mean"),
+ (
+ "pick_own_color_player_red_mean",
+ "pick_own_color_player_blue_mean",
+ "pick_speed_player_red_mean",
+ "pick_speed_player_blue_mean",
+ ),
]
@@ -32,6 +37,7 @@ class CoinGame(InfoAccumulationInterface, MultiAgentEnv, gym.Env):
"""
Coin Game environment.
"""
+
NAME = "CoinGame"
NUM_AGENTS = 2
NUM_ACTIONS = 4
@@ -55,7 +61,7 @@ def __init__(self, config: Dict = {}):
low=0,
high=1,
shape=(self.grid_size, self.grid_size, 4),
- dtype="uint8"
+ dtype="uint8",
)
self.step_count_in_current_episode = None
@@ -69,17 +75,22 @@ def _validate_config(self, config):
assert len(config["players_ids"]) == self.NUM_AGENTS
def _load_config(self, config):
- self.players_ids = \
- config.get("players_ids", ["player_red", "player_blue"])
+ self.players_ids = config.get(
+ "players_ids", ["player_red", "player_blue"]
+ )
self.max_steps = config.get("max_steps", 20)
self.grid_size = config.get("grid_size", 3)
- self.output_additional_info = config.get("output_additional_info",
- True)
+ self.output_additional_info = config.get(
+ "output_additional_info", True
+ )
self.asymmetric = config.get("asymmetric", False)
- self.both_players_can_pick_the_same_coin = \
- config.get("both_players_can_pick_the_same_coin", True)
- self.same_obs_for_each_player = \
- config.get("same_obs_for_each_player", True)
+ self.both_players_can_pick_the_same_coin = config.get(
+ "both_players_can_pick_the_same_coin", True
+ )
+ self.same_obs_for_each_player = config.get(
+ "same_obs_for_each_player", True
+ )
+ self.punishment_helped = config.get("punishment_helped", False)
@override(gym.Env)
def seed(self, seed=None):
@@ -89,6 +100,7 @@ def seed(self, seed=None):
@override(gym.Env)
def reset(self):
+ # print("reset")
self.step_count_in_current_episode = 0
if self.output_additional_info:
@@ -98,18 +110,17 @@ def reset(self):
self._generate_coin()
obs = self._generate_observation()
- return {
- self.player_red_id: obs[0],
- self.player_blue_id: obs[1]
- }
+ return {self.player_red_id: obs[0], self.player_blue_id: obs[1]}
def _randomize_color_and_player_positions(self):
# Reset coin color and the players and coin positions
self.red_coin = self.np_random.randint(low=0, high=2)
- self.red_pos = \
- self.np_random.randint(low=0, high=self.grid_size, size=(2,))
- self.blue_pos = \
- self.np_random.randint(low=0, high=self.grid_size, size=(2,))
+ self.red_pos = self.np_random.randint(
+ low=0, high=self.grid_size, size=(2,)
+ )
+ self.blue_pos = self.np_random.randint(
+ low=0, high=self.grid_size, size=(2,)
+ )
# self.coin_pos = np.zeros(shape=(2,), dtype=np.int8)
self._players_do_not_overlap_at_start()
@@ -165,51 +176,62 @@ def _same_pos(self, x, y):
return (x == y).all()
def _move_players(self, actions):
- self.red_pos = \
- (self.red_pos + self.MOVES[actions[0]]) % self.grid_size
- self.blue_pos = \
- (self.blue_pos + self.MOVES[actions[1]]) % self.grid_size
+ self.red_pos = (self.red_pos + self.MOVES[actions[0]]) % self.grid_size
+ self.blue_pos = (
+ self.blue_pos + self.MOVES[actions[1]]
+ ) % self.grid_size
def _compute_reward(self):
reward_red = 0.0
reward_blue = 0.0
generate_new_coin = False
- red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = \
- False, False, False, False
+ red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = (
+ False,
+ False,
+ False,
+ False,
+ )
red_first_if_both = None
if not self.both_players_can_pick_the_same_coin:
- if self._same_pos(self.red_pos, self.coin_pos) and \
- self._same_pos(self.blue_pos, self.coin_pos):
+ if self._same_pos(self.red_pos, self.coin_pos) and self._same_pos(
+ self.blue_pos, self.coin_pos
+ ):
red_first_if_both = bool(self.np_random.randint(low=0, high=2))
if self.red_coin:
- if self._same_pos(self.red_pos, self.coin_pos) and \
- (red_first_if_both is None or red_first_if_both):
+ if self._same_pos(self.red_pos, self.coin_pos) and (
+ red_first_if_both is None or red_first_if_both
+ ):
generate_new_coin = True
reward_red += 1
if self.asymmetric:
reward_red += 3
red_pick_any = True
red_pick_red = True
- if self._same_pos(self.blue_pos, self.coin_pos) and \
- (red_first_if_both is None or not red_first_if_both):
+ if self._same_pos(self.blue_pos, self.coin_pos) and (
+ red_first_if_both is None or not red_first_if_both
+ ):
generate_new_coin = True
reward_red += -2
reward_blue += 1
blue_pick_any = True
+ if self.asymmetric and self.punishment_helped:
+ reward_red -= 6
else:
- if self._same_pos(self.red_pos, self.coin_pos) and \
- (red_first_if_both is None or red_first_if_both):
+ if self._same_pos(self.red_pos, self.coin_pos) and (
+ red_first_if_both is None or red_first_if_both
+ ):
generate_new_coin = True
reward_red += 1
reward_blue += -2
if self.asymmetric:
reward_red += 3
red_pick_any = True
- if self._same_pos(self.blue_pos, self.coin_pos) and \
- (red_first_if_both is None or not red_first_if_both):
+ if self._same_pos(self.blue_pos, self.coin_pos) and (
+ red_first_if_both is None or not red_first_if_both
+ ):
generate_new_coin = True
reward_blue += 1
blue_pick_blue = True
@@ -217,10 +239,12 @@ def _compute_reward(self):
reward_list = [reward_red, reward_blue]
if self.output_additional_info:
- self._accumulate_info(red_pick_any=red_pick_any,
- red_pick_red=red_pick_red,
- blue_pick_any=blue_pick_any,
- blue_pick_blue=blue_pick_blue)
+ self._accumulate_info(
+ red_pick_any=red_pick_any,
+ red_pick_red=red_pick_red,
+ blue_pick_any=blue_pick_any,
+ blue_pick_blue=blue_pick_blue,
+ )
return reward_list, generate_new_coin
@@ -260,11 +284,12 @@ def _to_RLLib_API(self, observations, rewards):
self.player_blue_id: rewards[1],
}
- epi_is_done = (self.step_count_in_current_episode >= self.max_steps)
+ epi_is_done = self.step_count_in_current_episode >= self.max_steps
if self.step_count_in_current_episode > self.max_steps:
logger.warning(
"step_count_in_current_episode > self.max_steps: "
- f"{self.step_count_in_current_episode} > {self.max_steps}")
+ f"{self.step_count_in_current_episode} > {self.max_steps}"
+ )
done = {
self.player_red_id: epi_is_done,
@@ -283,7 +308,7 @@ def _to_RLLib_API(self, observations, rewards):
return obs, rewards, done, info
@override(InfoAccumulationInterface)
- def _get_episode_info(self):
+ def _get_episode_info(self, n_steps_played=None):
"""
Output the following information:
pick_speed is the fraction of steps during which the player picked a
@@ -292,20 +317,25 @@ def _get_episode_info(self):
the same color as the player.
"""
player_red_info, player_blue_info = {}, {}
+ if n_steps_played is None:
+ n_steps_played = len(self.red_pick)
+ assert len(self.red_pick) == len(self.blue_pick)
if len(self.red_pick) > 0:
red_pick = sum(self.red_pick)
- player_red_info["pick_speed"] = red_pick / len(self.red_pick)
+ player_red_info["pick_speed"] = red_pick / n_steps_played
if red_pick > 0:
- player_red_info["pick_own_color"] = \
+ player_red_info["pick_own_color"] = (
sum(self.red_pick_own) / red_pick
+ )
if len(self.blue_pick) > 0:
blue_pick = sum(self.blue_pick)
- player_blue_info["pick_speed"] = blue_pick / len(self.blue_pick)
+ player_blue_info["pick_speed"] = blue_pick / n_steps_played
if blue_pick > 0:
- player_blue_info["pick_own_color"] = \
+ player_blue_info["pick_own_color"] = (
sum(self.blue_pick_own) / blue_pick
+ )
return player_red_info, player_blue_info
@@ -318,7 +348,8 @@ def _reset_info(self):
@override(InfoAccumulationInterface)
def _accumulate_info(
- self, red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue):
+ self, red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue
+ ):
self.red_pick.append(red_pick_any)
self.red_pick_own.append(red_pick_red)
diff --git a/marltoolbox/envs/matrix_sequential_social_dilemma.py b/marltoolbox/envs/matrix_sequential_social_dilemma.py
index 52fd4f3..05385a4 100644
--- a/marltoolbox/envs/matrix_sequential_social_dilemma.py
+++ b/marltoolbox/envs/matrix_sequential_social_dilemma.py
@@ -43,7 +43,7 @@ class MatrixSequentialSocialDilemma(
"""
A multi-agent abstract class for two player matrix games.
- PAYOUT_MATRIX: Numpy array. Along the dimension N, the action of the
+ PAYOFF_MATRIX: Numpy array. Along the dimension N, the action of the
Nth player change. The last dimension is used to select the player
whose reward you want to know.
@@ -56,10 +56,22 @@ class MatrixSequentialSocialDilemma(
episode.
"""
+ NUM_AGENTS = 2
+ NUM_ACTIONS = None
+ NUM_STATES = None
+ ACTION_SPACE = None
+ OBSERVATION_SPACE = None
+ PAYOFF_MATRIX = None
+ NAME = None
+
def __init__(self, config: Dict = {}):
assert "reward_randomness" not in config.keys()
- assert self.PAYOUT_MATRIX is not None
+ assert self.PAYOFF_MATRIX is not None
+ assert self.PAYOFF_MATRIX.shape[0] == self.NUM_ACTIONS
+ assert self.PAYOFF_MATRIX.shape[1] == self.NUM_ACTIONS
+ assert self.PAYOFF_MATRIX.shape[2] == self.NUM_AGENTS
+ assert len(self.PAYOFF_MATRIX.shape) == 3
if "players_ids" in config:
assert (
isinstance(config["players_ids"], Iterable)
@@ -156,8 +168,8 @@ def _produce_observations_invariant_to_the_player_trained(
def _get_players_rewards(self, action_player_0: int, action_player_1: int):
return [
- self.PAYOUT_MATRIX[action_player_0][action_player_1][0],
- self.PAYOUT_MATRIX[action_player_0][action_player_1][1],
+ self.PAYOFF_MATRIX[action_player_0][action_player_1][0],
+ self.PAYOFF_MATRIX[action_player_0][action_player_1][1],
]
def _to_RLLib_API(
@@ -205,12 +217,11 @@ class IteratedMatchingPennies(
A two-agent environment for the Matching Pennies game.
"""
- NUM_AGENTS = 2
NUM_ACTIONS = 2
- NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1
+ NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1
ACTION_SPACE = Discrete(NUM_ACTIONS)
OBSERVATION_SPACE = Discrete(NUM_STATES)
- PAYOUT_MATRIX = np.array([[[+1, -1], [-1, +1]], [[-1, +1], [+1, -1]]])
+ PAYOFF_MATRIX = np.array([[[+1, -1], [-1, +1]], [[-1, +1], [+1, -1]]])
NAME = "IMP"
@@ -221,12 +232,11 @@ class IteratedPrisonersDilemma(
A two-agent environment for the Prisoner's Dilemma game.
"""
- NUM_AGENTS = 2
NUM_ACTIONS = 2
- NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1
+ NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1
ACTION_SPACE = Discrete(NUM_ACTIONS)
OBSERVATION_SPACE = Discrete(NUM_STATES)
- PAYOUT_MATRIX = np.array([[[-1, -1], [-3, +0]], [[+0, -3], [-2, -2]]])
+ PAYOFF_MATRIX = np.array([[[-1, -1], [-3, +0]], [[+0, -3], [-2, -2]]])
NAME = "IPD"
@@ -237,12 +247,11 @@ class IteratedAsymPrisonersDilemma(
A two-agent environment for the Asymmetric Prisoner's Dilemma game.
"""
- NUM_AGENTS = 2
NUM_ACTIONS = 2
- NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1
+ NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1
ACTION_SPACE = Discrete(NUM_ACTIONS)
OBSERVATION_SPACE = Discrete(NUM_STATES)
- PAYOUT_MATRIX = np.array([[[+0, -1], [-3, +0]], [[+0, -3], [-2, -2]]])
+ PAYOFF_MATRIX = np.array([[[+0, -1], [-3, +0]], [[+0, -3], [-2, -2]]])
NAME = "IPD"
@@ -253,12 +262,11 @@ class IteratedStagHunt(
A two-agent environment for the Stag Hunt game.
"""
- NUM_AGENTS = 2
NUM_ACTIONS = 2
- NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1
+ NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1
ACTION_SPACE = Discrete(NUM_ACTIONS)
OBSERVATION_SPACE = Discrete(NUM_STATES)
- PAYOUT_MATRIX = np.array([[[3, 3], [0, 2]], [[2, 0], [1, 1]]])
+ PAYOFF_MATRIX = np.array([[[3, 3], [0, 2]], [[2, 0], [1, 1]]])
NAME = "IteratedStagHunt"
@@ -269,12 +277,11 @@ class IteratedChicken(
A two-agent environment for the Chicken game.
"""
- NUM_AGENTS = 2
NUM_ACTIONS = 2
- NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1
+ NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1
ACTION_SPACE = Discrete(NUM_ACTIONS)
OBSERVATION_SPACE = Discrete(NUM_STATES)
- PAYOUT_MATRIX = np.array(
+ PAYOFF_MATRIX = np.array(
[[[+0, +0], [-1.0, +1.0]], [[+1, -1], [-10, -10]]]
)
NAME = "IteratedChicken"
@@ -287,12 +294,11 @@ class IteratedAsymChicken(
A two-agent environment for the Asymmetric Chicken game.
"""
- NUM_AGENTS = 2
NUM_ACTIONS = 2
- NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1
+ NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1
ACTION_SPACE = Discrete(NUM_ACTIONS)
OBSERVATION_SPACE = Discrete(NUM_STATES)
- PAYOUT_MATRIX = np.array(
+ PAYOFF_MATRIX = np.array(
[[[+2.0, +0], [-1.0, +1.0]], [[+2.5, -1], [-10, -10]]]
)
NAME = "AsymmetricIteratedChicken"
@@ -305,12 +311,11 @@ class IteratedBoS(
A two-agent environment for the BoS game.
"""
- NUM_AGENTS = 2
NUM_ACTIONS = 2
- NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1
+ NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1
ACTION_SPACE = Discrete(NUM_ACTIONS)
OBSERVATION_SPACE = Discrete(NUM_STATES)
- PAYOUT_MATRIX = np.array(
+ PAYOFF_MATRIX = np.array(
[[[+3.0, +2.0], [+0.0, +0.0]], [[+0.0, +0.0], [+2.0, +3.0]]]
)
NAME = "IteratedBoS"
@@ -323,31 +328,31 @@ class IteratedAsymBoS(
A two-agent environment for the BoS game.
"""
- NUM_AGENTS = 2
NUM_ACTIONS = 2
- NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1
+ NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1
ACTION_SPACE = Discrete(NUM_ACTIONS)
OBSERVATION_SPACE = Discrete(NUM_STATES)
- PAYOUT_MATRIX = np.array(
+ PAYOFF_MATRIX = np.array(
[[[+4.0, +1.0], [+0.0, +0.0]], [[+0.0, +0.0], [+2.0, +2.0]]]
)
- NAME = "AsymmetricIteratedBoS"
+ NAME = "IteratedAsymBoS"
def define_greed_fear_matrix_game(greed, fear):
class GreedFearGame(
TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma
):
- NUM_AGENTS = 2
NUM_ACTIONS = 2
- NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1
+ NUM_STATES = (
+ NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1
+ )
ACTION_SPACE = Discrete(NUM_ACTIONS)
OBSERVATION_SPACE = Discrete(NUM_STATES)
R = 3
P = 1
T = R + greed
S = P - fear
- PAYOUT_MATRIX = np.array([[[R, R], [S, T]], [[T, S], [P, P]]])
+ PAYOFF_MATRIX = np.array([[[R, R], [S, T]], [[T, S], [P, P]]])
NAME = "IteratedGreedFear"
def __str__(self):
@@ -363,12 +368,11 @@ class IteratedBoSAndPD(
A two-agent environment for the BOTS + PD game.
"""
- NUM_AGENTS = 2
NUM_ACTIONS = 3
- NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1
+ NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1
ACTION_SPACE = Discrete(NUM_ACTIONS)
OBSERVATION_SPACE = Discrete(NUM_STATES)
- PAYOUT_MATRIX = np.array(
+ PAYOFF_MATRIX = np.array(
[
[[3.5, +1], [+0, +0], [-3, +2]],
[[+0.0, +0], [+1, +3], [-3, +2]],
@@ -376,3 +380,28 @@ class IteratedBoSAndPD(
]
)
NAME = "IteratedBoSAndPD"
+
+
+class TwoPlayersCustomizableMatrixGame(
+ NPlayersNDiscreteActionsInfoMixin, MatrixSequentialSocialDilemma
+):
+ """
+ A two-agent matrix game whose payoff matrix is provided via the config.
+ """
+
+ NAME = "TwoPlayersCustomizableMatrixGame"
+
+ NUM_ACTIONS = None
+ NUM_STATES = None
+ ACTION_SPACE = None
+ OBSERVATION_SPACE = None
+ PAYOFF_MATRIX = None
+
+ def __init__(self, config: Dict):
+ self.PAYOFF_MATRIX = config["PAYOFF_MATRIX"]
+ self.NUM_ACTIONS = config["NUM_ACTIONS"]
+ self.ACTION_SPACE = Discrete(self.NUM_ACTIONS)
+ self.NUM_STATES = self.NUM_ACTIONS ** self.NUM_AGENTS + 1
+ self.OBSERVATION_SPACE = Discrete(self.NUM_STATES)
+
+ super().__init__(config)
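+
+
+# A hypothetical usage sketch (values are illustrative only):
+#   TwoPlayersCustomizableMatrixGame(
+#       {
+#           "PAYOFF_MATRIX": np.array([[[3, 3], [0, 5]], [[5, 0], [1, 1]]]),
+#           "NUM_ACTIONS": 2,
+#       }
+#   )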
diff --git a/marltoolbox/envs/ssd_mixed_motive_coin_game.py b/marltoolbox/envs/ssd_mixed_motive_coin_game.py
index 0af66f7..03dbf84 100644
--- a/marltoolbox/envs/ssd_mixed_motive_coin_game.py
+++ b/marltoolbox/envs/ssd_mixed_motive_coin_game.py
@@ -40,6 +40,7 @@ class SSDMixedMotiveCoinGame(coin_game.CoinGame):
@override(coin_game.CoinGame)
def __init__(self, config: dict = {}):
super().__init__(config)
+ self.punishment_helped = config.get("punishment_helped", False)
self.OBSERVATION_SPACE = gym.spaces.Box(
low=0,
@@ -137,11 +138,11 @@ def _compute_reward(self):
if self._same_pos(self.red_pos, self.red_coin_pos):
if self.red_coin and self._same_pos(
- self.blue_pos, self.red_coin_pos
+ self.blue_pos, self.red_coin_pos
):
# Red coin is a coop coin
generate_new_coin = True
- reward_red += 1.2
+ reward_red += 2.0
red_pick_any = True
red_pick_red = True
blue_pick_any = True
@@ -152,13 +153,18 @@ def _compute_reward(self):
reward_red += 1.0
red_pick_any = True
red_pick_red = True
+ if self.punishment_helped and self._same_pos(
+ self.blue_pos, self.red_coin_pos
+ ):
+ reward_red -= 0.75
+
elif self._same_pos(self.blue_pos, self.blue_coin_pos):
if not self.red_coin and self._same_pos(
- self.red_pos, self.blue_coin_pos
+ self.red_pos, self.blue_coin_pos
):
# Blue coin is a coop coin
generate_new_coin = True
- reward_blue += 2.0
+ reward_blue += 3.0
red_pick_any = True
blue_pick_any = True
blue_pick_blue = True
@@ -169,6 +175,10 @@ def _compute_reward(self):
reward_blue += 1.0
blue_pick_any = True
blue_pick_blue = True
+ if self.punishment_helped and self._same_pos(
+ self.red_pos, self.blue_coin_pos
+ ):
+ reward_red -= 0.75
reward_list = [reward_red, reward_blue]
if self.output_additional_info:
@@ -185,37 +195,29 @@ def _compute_reward(self):
@override(coin_game.CoinGame)
def _init_info(self):
- self.red_pick = []
- self.red_pick_own = []
- self.blue_pick = []
- self.blue_pick_own = []
+ super()._init_info()
self.picked_red_coop = []
self.picked_blue_coop = []
@override(coin_game.CoinGame)
def _reset_info(self):
- self.red_pick.clear()
- self.red_pick_own.clear()
- self.blue_pick.clear()
- self.blue_pick_own.clear()
+ super()._reset_info()
self.picked_red_coop.clear()
self.picked_blue_coop.clear()
@override(coin_game.CoinGame)
def _accumulate_info(
- self,
- red_pick_any,
- red_pick_red,
- blue_pick_any,
- blue_pick_blue,
- picked_red_coop,
- picked_blue_coop,
+ self,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ picked_red_coop,
+ picked_blue_coop,
):
-
- self.red_pick.append(red_pick_any)
- self.red_pick_own.append(red_pick_red)
- self.blue_pick.append(blue_pick_any)
- self.blue_pick_own.append(blue_pick_blue)
+ super()._accumulate_info(
+ red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue
+ )
self.picked_red_coop.append(picked_red_coop)
self.picked_blue_coop.append(picked_blue_coop)
@@ -228,38 +230,25 @@ def _get_episode_info(self):
pick_own_color is the fraction of coins picked by the player which have
the same color as the player.
"""
- player_red_info, player_blue_info = {}, {}
+ player_red_info, player_blue_info = super()._get_episode_info()
+
n_steps_played = len(self.red_pick)
assert n_steps_played == len(self.blue_pick)
n_coop = sum(self.picked_blue_coop) + sum(self.picked_red_coop)
if len(self.red_pick) > 0:
red_pick = sum(self.red_pick)
- player_red_info["pick_speed"] = red_pick / n_steps_played
- if red_pick > 0:
- player_red_info["pick_own_color"] = (
- sum(self.red_pick_own) / red_pick
- )
-
player_red_info["red_coop_speed"] = (
- sum(self.picked_red_coop) / n_steps_played
+ sum(self.picked_red_coop) / n_steps_played
)
-
if red_pick > 0:
player_red_info["red_coop_fraction"] = n_coop / red_pick
if len(self.blue_pick) > 0:
blue_pick = sum(self.blue_pick)
- player_blue_info["pick_speed"] = blue_pick / n_steps_played
- if blue_pick > 0:
- player_blue_info["pick_own_color"] = (
- sum(self.blue_pick_own) / blue_pick
- )
-
player_blue_info["blue_coop_speed"] = (
- sum(self.picked_blue_coop) / n_steps_played
+ sum(self.picked_blue_coop) / n_steps_played
)
-
if blue_pick > 0:
player_blue_info["blue_coop_fraction"] = n_coop / blue_pick
diff --git a/marltoolbox/envs/vectorized_coin_game.py b/marltoolbox/envs/vectorized_coin_game.py
index 5f9f4f6..0fed500 100644
--- a/marltoolbox/envs/vectorized_coin_game.py
+++ b/marltoolbox/envs/vectorized_coin_game.py
@@ -28,17 +28,21 @@ def __init__(self, config={}):
self.batch_size = config.get("batch_size", 1)
self.force_vectorized = config.get("force_vectorize", False)
- assert self.grid_size == 3, \
- "hardcoded in the generate_state numba function"
+ self.punishment_helped = config.get("punishment_helped", False)
+ assert (
+ self.grid_size == 3
+ ), "hardcoded in the generate_state numba function"
@override(coin_game.CoinGame)
def _randomize_color_and_player_positions(self):
# Reset coin color and the players and coin positions
- self.red_coin = np.random.randint(2, size=self.batch_size)
+ self.red_coin = np.random.randint(low=0, high=2, size=self.batch_size)
self.red_pos = np.random.randint(
- self.grid_size, size=(self.batch_size, 2))
+ self.grid_size, size=(self.batch_size, 2)
+ )
self.blue_pos = np.random.randint(
- self.grid_size, size=(self.batch_size, 2))
+ self.grid_size, size=(self.batch_size, 2)
+ )
self.coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8)
self._players_do_not_overlap_at_start()
@@ -52,18 +56,30 @@ def _players_do_not_overlap_at_start(self):
@override(coin_game.CoinGame)
def _generate_coin(self):
generate = np.ones(self.batch_size, dtype=bool)
- self.coin_pos = generate_coin_wt_numba_optimization(
- self.batch_size, generate, self.red_coin, self.red_pos,
- self.blue_pos, self.coin_pos, self.grid_size)
+ self.coin_pos, self.red_coin = generate_coin_wt_numba_optimization(
+ self.batch_size,
+ generate,
+ self.red_coin,
+ self.red_pos,
+ self.blue_pos,
+ self.coin_pos,
+ self.grid_size,
+ )
@override(coin_game.CoinGame)
def _generate_observation(self):
obs = generate_observations_wt_numba_optimization(
- self.batch_size, self.red_pos, self.blue_pos, self.coin_pos,
- self.red_coin, self.grid_size)
+ self.batch_size,
+ self.red_pos,
+ self.blue_pos,
+ self.coin_pos,
+ self.red_coin,
+ self.grid_size,
+ )
obs = self._apply_optional_invariance_to_the_player_trained(obs)
obs, _ = self._optional_unvectorize(obs)
+ # print("env nonzero obs", np.nonzero(obs[0]))
return obs
def _optional_unvectorize(self, obs, rewards=None):
@@ -75,49 +91,65 @@ def _optional_unvectorize(self, obs, rewards=None):
@override(coin_game.CoinGame)
def step(self, actions: Iterable):
-
+ # print("step")
+ # print("env self.red_coin", self.red_coin)
+ # print("env self.red_pos", self.red_pos)
+ # print("env self.blue_pos", self.blue_pos)
+ # print("env self.coin_pos", self.coin_pos)
actions = self._from_RLLib_API_to_list(actions)
self.step_count_in_current_episode += 1
- (self.red_pos, self.blue_pos, rewards, self.coin_pos, observation,
- self.red_coin, red_pick_any, red_pick_red, blue_pick_any,
- blue_pick_blue) = vectorized_step_wt_numba_optimization(
- actions, self.batch_size, self.red_pos, self.blue_pos,
- self.coin_pos, self.red_coin, self.grid_size, self.asymmetric,
- self.max_steps, self.both_players_can_pick_the_same_coin)
+ (
+ self.red_pos,
+ self.blue_pos,
+ rewards,
+ self.coin_pos,
+ observation,
+ self.red_coin,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ ) = vectorized_step_wt_numba_optimization(
+ actions,
+ self.batch_size,
+ self.red_pos,
+ self.blue_pos,
+ self.coin_pos,
+ self.red_coin,
+ self.grid_size,
+ self.asymmetric,
+ self.max_steps,
+ self.both_players_can_pick_the_same_coin,
+ self.punishment_helped,
+ )
if self.output_additional_info:
self._accumulate_info(
- red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue)
+ red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue
+ )
obs = self._apply_optional_invariance_to_the_player_trained(
- observation)
+ observation
+ )
obs, rewards = self._optional_unvectorize(obs, rewards)
-
+ # print("env actions", actions)
+ # print("env rewards", rewards)
+ # print("env self.red_coin", self.red_coin)
+ # print("env self.red_pos", self.red_pos)
+ # print("env self.blue_pos", self.blue_pos)
+ # print("env self.coin_pos", self.coin_pos)
+ # print("env nonzero obs", np.nonzero(obs[0]))
return self._to_RLLib_API(obs, rewards)
@override(coin_game.CoinGame)
- def _get_episode_info(self):
-
- player_red_info, player_blue_info = {}, {}
-
- if len(self.red_pick) > 0:
- red_pick = sum(self.red_pick)
- player_red_info["pick_speed"] = \
- red_pick / (len(self.red_pick) * self.batch_size)
- if red_pick > 0:
- player_red_info["pick_own_color"] = \
- sum(self.red_pick_own) / red_pick
-
- if len(self.blue_pick) > 0:
- blue_pick = sum(self.blue_pick)
- player_blue_info["pick_speed"] = \
- blue_pick / (len(self.blue_pick) * self.batch_size)
- if blue_pick > 0:
- player_blue_info["pick_own_color"] = \
- sum(self.blue_pick_own) / blue_pick
-
- return player_red_info, player_blue_info
+ def _get_episode_info(self, n_steps_played=None):
+ n_steps_played = (
+ n_steps_played
+ if n_steps_played is not None
+ else len(self.red_pick) * self.batch_size
+ )
+ return super()._get_episode_info(n_steps_played=n_steps_played)
@override(coin_game.CoinGame)
def _from_RLLib_API_to_list(self, actions):
@@ -140,15 +172,13 @@ def _save_env(self):
"grid_size": self.grid_size,
"asymmetric": self.asymmetric,
"batch_size": self.batch_size,
- "step_count_in_current_episode":
- self.step_count_in_current_episode,
+ "step_count_in_current_episode": self.step_count_in_current_episode,
"max_steps": self.max_steps,
"red_pick": self.red_pick,
"red_pick_own": self.red_pick_own,
"blue_pick": self.blue_pick,
"blue_pick_own": self.blue_pick_own,
- "both_players_can_pick_the_same_coin":
- self.both_players_can_pick_the_same_coin,
+ "both_players_can_pick_the_same_coin": self.both_players_can_pick_the_same_coin,
}
return copy.deepcopy(env_save_state)
@@ -170,94 +200,149 @@ def __init__(self, config={}):
@jit(nopython=True)
def vectorized_step_wt_numba_optimization(
- actions, batch_size, red_pos, blue_pos, coin_pos, red_coin,
- grid_size: int, asymmetric: bool, max_steps: int,
- both_players_can_pick_the_same_coin: bool):
+ actions,
+ batch_size,
+ red_pos,
+ blue_pos,
+ coin_pos,
+ red_coin,
+ grid_size: int,
+ asymmetric: bool,
+ max_steps: int,
+ both_players_can_pick_the_same_coin: bool,
+ punishment_helped: bool,
+):
red_pos, blue_pos = move_players(
- batch_size, actions, red_pos, blue_pos, grid_size)
+ batch_size, actions, red_pos, blue_pos, grid_size
+ )
- reward, generate, red_pick_any, red_pick_red, \
- blue_pick_any, blue_pick_blue = compute_reward(
- batch_size, red_pos, blue_pos, coin_pos, red_coin,
- asymmetric, both_players_can_pick_the_same_coin)
+ (
+ reward,
+ generate,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ ) = compute_reward(
+ batch_size,
+ red_pos,
+ blue_pos,
+ coin_pos,
+ red_coin,
+ asymmetric,
+ both_players_can_pick_the_same_coin,
+ punishment_helped,
+ )
- coin_pos = generate_coin_wt_numba_optimization(
- batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size)
+ coin_pos, red_coin = generate_coin_wt_numba_optimization(
+ batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size
+ )
obs = generate_observations_wt_numba_optimization(
- batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size)
+ batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size
+ )
- return red_pos, blue_pos, reward, coin_pos, obs, red_coin, red_pick_any, \
- red_pick_red, blue_pick_any, blue_pick_blue
+ return (
+ red_pos,
+ blue_pos,
+ reward,
+ coin_pos,
+ obs,
+ red_coin,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ )
@jit(nopython=True)
def move_players(batch_size, actions, red_pos, blue_pos, grid_size):
- moves = List([
- np.array([0, 1]),
- np.array([0, -1]),
- np.array([1, 0]),
- np.array([-1, 0]),
- ])
+ moves = List(
+ [
+ np.array([0, 1]),
+ np.array([0, -1]),
+ np.array([1, 0]),
+ np.array([-1, 0]),
+ ]
+ )
for j in prange(batch_size):
- red_pos[j] = \
- (red_pos[j] + moves[actions[j, 0]]) % grid_size
- blue_pos[j] = \
- (blue_pos[j] + moves[actions[j, 1]]) % grid_size
+ red_pos[j] = (red_pos[j] + moves[actions[j, 0]]) % grid_size
+ blue_pos[j] = (blue_pos[j] + moves[actions[j, 1]]) % grid_size
return red_pos, blue_pos
@jit(nopython=True)
-def compute_reward(batch_size, red_pos, blue_pos, coin_pos, red_coin,
- asymmetric, both_players_can_pick_the_same_coin):
+def compute_reward(
+ batch_size,
+ red_pos,
+ blue_pos,
+ coin_pos,
+ red_coin,
+ asymmetric,
+ both_players_can_pick_the_same_coin,
+ punishment_helped,
+):
reward_red = np.zeros(batch_size)
reward_blue = np.zeros(batch_size)
generate = np.zeros(batch_size, dtype=np.bool_)
- red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = \
- 0, 0, 0, 0
+ red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0
for i in prange(batch_size):
red_first_if_both = None
if not both_players_can_pick_the_same_coin:
- if _same_pos(red_pos[i], coin_pos[i]) and \
- _same_pos(blue_pos[i], coin_pos[i]):
+ if _same_pos(red_pos[i], coin_pos[i]) and _same_pos(
+ blue_pos[i], coin_pos[i]
+ ):
red_first_if_both = bool(np.random.randint(low=0, high=2))
if red_coin[i]:
- if _same_pos(red_pos[i], coin_pos[i]) and \
- (red_first_if_both is None or red_first_if_both):
+ if _same_pos(red_pos[i], coin_pos[i]) and (
+ red_first_if_both is None or red_first_if_both
+ ):
generate[i] = True
reward_red[i] += 1
if asymmetric:
reward_red[i] += 3
red_pick_any += 1
red_pick_red += 1
- if _same_pos(blue_pos[i], coin_pos[i]) and \
- (red_first_if_both is None or not red_first_if_both):
+ if _same_pos(blue_pos[i], coin_pos[i]) and (
+ red_first_if_both is None or not red_first_if_both
+ ):
generate[i] = True
reward_red[i] += -2
reward_blue[i] += 1
blue_pick_any += 1
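+ # With punishment_helped in the asymmetric game, the red player is
+ # punished further when the blue player picks the red coin.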
+ if asymmetric and punishment_helped:
+ reward_red[i] -= 6
else:
- if _same_pos(red_pos[i], coin_pos[i]) and \
- (red_first_if_both is None or red_first_if_both):
+ if _same_pos(red_pos[i], coin_pos[i]) and (
+ red_first_if_both is None or red_first_if_both
+ ):
generate[i] = True
reward_red[i] += 1
reward_blue[i] += -2
if asymmetric:
reward_red[i] += 3
red_pick_any += 1
- if _same_pos(blue_pos[i], coin_pos[i]) and \
- (red_first_if_both is None or not red_first_if_both):
+ if _same_pos(blue_pos[i], coin_pos[i]) and (
+ red_first_if_both is None or not red_first_if_both
+ ):
generate[i] = True
reward_blue[i] += 1
blue_pick_any += 1
blue_pick_blue += 1
reward = [reward_red, reward_blue]
- return reward, generate, \
- red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue
+ return (
+ reward,
+ generate,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ )
@jit(nopython=True)
@@ -267,13 +352,13 @@ def _same_pos(x, y):
@jit(nopython=True)
def generate_coin_wt_numba_optimization(
- batch_size, generate, red_coin, red_pos, blue_pos, coin_pos,
- grid_size):
+ batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size
+):
red_coin[generate] = 1 - red_coin[generate]
for i in prange(batch_size):
if generate[i]:
coin_pos[i] = _place_coin(red_pos[i], blue_pos[i], grid_size)
- return coin_pos
+ return coin_pos, red_coin
@jit(nopython=True)
@@ -303,8 +388,9 @@ def _unflatten_index(pos, grid_size):
@jit(nopython=True)
-def generate_observations_wt_numba_optimization(batch_size, red_pos, blue_pos,
- coin_pos, red_coin, grid_size):
+def generate_observations_wt_numba_optimization(
+ batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size
+):
obs = np.zeros((batch_size, grid_size, grid_size, 4))
for i in prange(batch_size):
obs[i, red_pos[i][0], red_pos[i][1], 0] = 1
diff --git a/marltoolbox/envs/vectorized_mixed_motive_coin_game.py b/marltoolbox/envs/vectorized_mixed_motive_coin_game.py
index ced01bd..a851f05 100644
--- a/marltoolbox/envs/vectorized_mixed_motive_coin_game.py
+++ b/marltoolbox/envs/vectorized_mixed_motive_coin_game.py
@@ -7,8 +7,12 @@
from ray.rllib.utils import override
from marltoolbox.envs import vectorized_coin_game
-from marltoolbox.envs.vectorized_coin_game import \
- _flatten_index, _unflatten_index, _same_pos, move_players
+from marltoolbox.envs.vectorized_coin_game import (
+ _flatten_index,
+ _unflatten_index,
+ _same_pos,
+ move_players,
+)
logger = logging.getLogger(__name__)
@@ -18,21 +22,23 @@
class VectMixedMotiveCG(vectorized_coin_game.VectorizedCoinGame):
-
@override(vectorized_coin_game.VectorizedCoinGame)
def _load_config(self, config):
super()._load_config(config)
- assert self.both_players_can_pick_the_same_coin, \
- "both_players_can_pick_the_same_coin option must be True in " \
+ assert self.both_players_can_pick_the_same_coin, (
+ "both_players_can_pick_the_same_coin option must be True in "
"mixed motive coin game."
+ )
@override(vectorized_coin_game.VectorizedCoinGame)
def _randomize_color_and_player_positions(self):
# Reset coin color and the players and coin positions
self.red_pos = np.random.randint(
- self.grid_size, size=(self.batch_size, 2))
+ self.grid_size, size=(self.batch_size, 2)
+ )
self.blue_pos = np.random.randint(
- self.grid_size, size=(self.batch_size, 2))
+ self.grid_size, size=(self.batch_size, 2)
+ )
self.red_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8)
self.blue_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8)
@@ -41,17 +47,29 @@ def _randomize_color_and_player_positions(self):
@override(vectorized_coin_game.VectorizedCoinGame)
def _generate_coin(self):
generate = np.ones(self.batch_size, dtype=np.int_) * 3
- self.red_coin_pos, self.blue_coin_pos = \
- generate_coin_wt_numba_optimization(
- self.batch_size, generate, self.red_coin_pos,
- self.blue_coin_pos, self.red_pos, self.blue_pos,
- self.grid_size)
+ (
+ self.red_coin_pos,
+ self.blue_coin_pos,
+ ) = generate_coin_wt_numba_optimization(
+ self.batch_size,
+ generate,
+ self.red_coin_pos,
+ self.blue_coin_pos,
+ self.red_pos,
+ self.blue_pos,
+ self.grid_size,
+ )
@override(vectorized_coin_game.VectorizedCoinGame)
def _generate_observation(self):
obs = generate_observations_wt_numba_optimization(
- self.batch_size, self.red_pos, self.blue_pos, self.red_coin_pos,
- self.blue_coin_pos, self.grid_size)
+ self.batch_size,
+ self.red_pos,
+ self.blue_pos,
+ self.red_coin_pos,
+ self.blue_coin_pos,
+ self.grid_size,
+ )
obs = self._apply_optional_invariance_to_the_player_trained(obs)
obs, _ = self._optional_unvectorize(obs)
@@ -62,19 +80,35 @@ def step(self, actions: Iterable):
actions = self._from_RLLib_API_to_list(actions)
self.step_count_in_current_episode += 1
- (self.red_pos, self.blue_pos, rewards, self.red_coin_pos,
- self.blue_coin_pos, observation, red_pick_any, red_pick_red,
- blue_pick_any, blue_pick_blue) = \
- vectorized_step_wt_numba_optimization(
- actions, self.batch_size, self.red_pos, self.blue_pos,
- self.red_coin_pos, self.blue_coin_pos, self.grid_size)
+ (
+ self.red_pos,
+ self.blue_pos,
+ rewards,
+ self.red_coin_pos,
+ self.blue_coin_pos,
+ observation,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ ) = vectorized_step_wt_numba_optimization(
+ actions,
+ self.batch_size,
+ self.red_pos,
+ self.blue_pos,
+ self.red_coin_pos,
+ self.blue_coin_pos,
+ self.grid_size,
+ )
if self.output_additional_info:
self._accumulate_info(
- red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue)
+ red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue
+ )
obs = self._apply_optional_invariance_to_the_player_trained(
- observation)
+ observation
+ )
obs, rewards = self._optional_unvectorize(obs, rewards)
return self._to_RLLib_API(obs, rewards)
@@ -88,8 +122,7 @@ def _save_env(self):
"blue_coin_pos": self.blue_coin_pos,
"grid_size": self.grid_size,
"batch_size": self.batch_size,
- "step_count_in_current_episode":
- self.step_count_in_current_episode,
+ "step_count_in_current_episode": self.step_count_in_current_episode,
"max_steps": self.max_steps,
"red_pick": self.red_pick,
"red_pick_own": self.red_pick_own,
@@ -101,24 +134,55 @@ def _save_env(self):
@jit(nopython=True)
def vectorized_step_wt_numba_optimization(
- actions, batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos,
- grid_size: int):
+ actions,
+ batch_size,
+ red_pos,
+ blue_pos,
+ red_coin_pos,
+ blue_coin_pos,
+ grid_size: int,
+):
red_pos, blue_pos = move_players(
- batch_size, actions, red_pos, blue_pos, grid_size)
+ batch_size, actions, red_pos, blue_pos, grid_size
+ )
- reward, generate, red_pick_any, red_pick_red, \
- blue_pick_any, blue_pick_blue = compute_reward(
- batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos)
+ (
+ reward,
+ generate,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ ) = compute_reward(
+ batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos
+ )
red_coin_pos, blue_coin_pos = generate_coin_wt_numba_optimization(
- batch_size, generate, red_coin_pos, blue_coin_pos,
- red_pos, blue_pos, grid_size)
+ batch_size,
+ generate,
+ red_coin_pos,
+ blue_coin_pos,
+ red_pos,
+ blue_pos,
+ grid_size,
+ )
obs = generate_observations_wt_numba_optimization(
- batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos, grid_size)
+ batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos, grid_size
+ )
- return red_pos, blue_pos, reward, red_coin_pos, blue_coin_pos, obs, \
- red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue
+ return (
+ red_pos,
+ blue_pos,
+ reward,
+ red_coin_pos,
+ blue_coin_pos,
+ obs,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ )
@jit(nopython=True)
@@ -126,20 +190,21 @@ def compute_reward(batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos):
reward_red = np.zeros(batch_size)
reward_blue = np.zeros(batch_size)
generate = np.zeros(batch_size, dtype=np.int_)
- red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = \
- 0, 0, 0, 0
+ red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0
for i in prange(batch_size):
- if _same_pos(red_pos[i], red_coin_pos[i]) and \
- _same_pos(blue_pos[i], red_coin_pos[i]):
+ if _same_pos(red_pos[i], red_coin_pos[i]) and _same_pos(
+ blue_pos[i], red_coin_pos[i]
+ ):
generate[i] = 1
reward_red[i] += 2
reward_blue[i] += 2
red_pick_any += 1
red_pick_red += 1
blue_pick_any += 1
- elif _same_pos(red_pos[i], blue_coin_pos[i]) and \
- _same_pos(blue_pos[i], blue_coin_pos[i]):
+ elif _same_pos(red_pos[i], blue_coin_pos[i]) and _same_pos(
+ blue_pos[i], blue_coin_pos[i]
+ ):
generate[i] = 2
reward_red[i] += 1
reward_blue[i] += 4
@@ -149,14 +214,26 @@ def compute_reward(batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos):
reward = [reward_red, reward_blue]
- return reward, generate, \
- red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue
+ return (
+ reward,
+ generate,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ )
@jit(nopython=True)
def generate_coin_wt_numba_optimization(
- batch_size, generate, red_coin_pos, blue_coin_pos, red_pos, blue_pos,
- grid_size):
+ batch_size,
+ generate,
+ red_coin_pos,
+ blue_coin_pos,
+ red_pos,
+ blue_pos,
+ grid_size,
+):
for i in prange(batch_size):
# generate:0 => no coin generation
# generate:1 => red coin generation
@@ -164,11 +241,13 @@ def generate_coin_wt_numba_optimization(
# generate:3 => red & blue coin generation
if generate[i] == 3 or generate[i] == 1:
- red_coin_pos[i] = _place_coin(red_pos[i], blue_pos[i],
- grid_size, blue_coin_pos[i])
+ red_coin_pos[i] = _place_coin(
+ red_pos[i], blue_pos[i], grid_size, blue_coin_pos[i]
+ )
if generate[i] == 3 or generate[i] == 2:
- blue_coin_pos[i] = _place_coin(red_pos[i], blue_pos[i],
- grid_size, red_coin_pos[i])
+ blue_coin_pos[i] = _place_coin(
+ red_pos[i], blue_pos[i], grid_size, red_coin_pos[i]
+ )
return red_coin_pos, blue_coin_pos
@@ -178,19 +257,24 @@ def _place_coin(red_pos_i, blue_pos_i, grid_size, other_coin_pos_i):
blue_pos_flat = _flatten_index(blue_pos_i, grid_size)
other_coin_pos_flat = _flatten_index(other_coin_pos_i, grid_size)
possible_coin_pos = np.array(
- [x for x in range(9)
- if ((x != blue_pos_flat) and
- (x != red_pos_flat) and
- (x != other_coin_pos_flat))]
+ [
+ x
+ for x in range(9)
+ if (
+ (x != blue_pos_flat)
+ and (x != red_pos_flat)
+ and (x != other_coin_pos_flat)
+ )
+ ]
)
flat_coin_pos = np.random.choice(possible_coin_pos)
return _unflatten_index(flat_coin_pos, grid_size)
@jit(nopython=True)
-def generate_observations_wt_numba_optimization(batch_size, red_pos, blue_pos,
- red_coin_pos, blue_coin_pos,
- grid_size):
+def generate_observations_wt_numba_optimization(
+ batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos, grid_size
+):
obs = np.zeros((batch_size, grid_size, grid_size, 4))
for i in prange(batch_size):
obs[i, red_pos[i][0], red_pos[i][1], 0] = 1
diff --git a/marltoolbox/envs/vectorized_ssd_mm_coin_game.py b/marltoolbox/envs/vectorized_ssd_mm_coin_game.py
new file mode 100644
index 0000000..ef6ee2b
--- /dev/null
+++ b/marltoolbox/envs/vectorized_ssd_mm_coin_game.py
@@ -0,0 +1,428 @@
+import copy
+import logging
+from collections import Iterable
+
+import gym
+import numpy as np
+from numba import jit, prange
+from ray.rllib.utils import override
+
+from marltoolbox.envs import vectorized_coin_game
+from marltoolbox.envs.vectorized_coin_game import (
+ _same_pos,
+ move_players,
+)
+from marltoolbox.envs.vectorized_mixed_motive_coin_game import _place_coin
+
+logger = logging.getLogger(__name__)
+
+PLOT_KEYS = vectorized_coin_game.PLOT_KEYS
+
+PLOT_ASSEMBLAGE_TAGS = vectorized_coin_game.PLOT_ASSEMBLAGE_TAGS
+
+
+class VectSSDMixedMotiveCG(
+ vectorized_coin_game.VectorizedCoinGame,
+):
+ def __init__(self, config: dict = {}):
+ super().__init__(config)
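+ # When punishment_helped is enabled, a player picking its selfish coin
+ # while the other player stands on the same cell receives an extra
+ # penalty (see compute_reward below).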
+ self.punishment_helped = config.get("punishment_helped", False)
+
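+ # 6 observation channels: red player, blue player, red selfish coin,
+ # blue selfish coin, red cooperative coin, blue cooperative coin
+ # (see generate_observations_wt_numba_optimization below).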
+ self.OBSERVATION_SPACE = gym.spaces.Box(
+ low=0,
+ high=1,
+ shape=(self.grid_size, self.grid_size, 6),
+ dtype="uint8",
+ )
+
+ @override(vectorized_coin_game.VectorizedCoinGame)
+ def _load_config(self, config):
+ super()._load_config(config)
+ assert self.both_players_can_pick_the_same_coin, (
+ "both_players_can_pick_the_same_coin option must be True in "
+ "ssd mixed motive coin game."
+ )
+ assert self.same_obs_for_each_player, (
+ "same_obs_for_each_player option must be True in "
+ "ssd mixed motive coin game."
+ )
+
+ @override(vectorized_coin_game.VectorizedCoinGame)
+ def _randomize_color_and_player_positions(self):
+ # Reset coin color and the players and coin positions
+ self.red_pos = np.random.randint(
+ self.grid_size, size=(self.batch_size, 2)
+ )
+ self.blue_pos = np.random.randint(
+ self.grid_size, size=(self.batch_size, 2)
+ )
+ self.red_coin = np.random.randint(
+ 2, size=self.batch_size, dtype=np.int8
+ )
+ self.red_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8)
+ self.blue_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8)
+
+ self._players_do_not_overlap_at_start()
+
+ @override(vectorized_coin_game.VectorizedCoinGame)
+ def _generate_coin(self):
+ generate = np.ones(self.batch_size, dtype=np.int_)
+ (
+ self.red_coin_pos,
+ self.blue_coin_pos,
+ self.red_coin,
+ ) = generate_coin_wt_numba_optimization(
+ self.batch_size,
+ generate,
+ self.red_coin_pos,
+ self.blue_coin_pos,
+ self.red_pos,
+ self.blue_pos,
+ self.grid_size,
+ self.red_coin,
+ )
+
+ @override(vectorized_coin_game.VectorizedCoinGame)
+ def _generate_observation(self):
+ obs = generate_observations_wt_numba_optimization(
+ self.batch_size,
+ self.red_pos,
+ self.blue_pos,
+ self.red_coin_pos,
+ self.blue_coin_pos,
+ self.grid_size,
+ self.red_coin,
+ )
+
+ obs = self._apply_optional_invariance_to_the_player_trained(obs)
+ obs, _ = self._optional_unvectorize(obs)
+ return obs
+
+ @override(vectorized_coin_game.VectorizedCoinGame)
+ def step(self, actions: Iterable):
+ actions = self._from_RLLib_API_to_list(actions)
+ self.step_count_in_current_episode += 1
+
+ (
+ self.red_pos,
+ self.blue_pos,
+ rewards,
+ self.red_coin_pos,
+ self.blue_coin_pos,
+ observation,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ picked_red_coop,
+ picked_blue_coop,
+ self.red_coin,
+ ) = vectorized_step_wt_numba_optimization(
+ actions,
+ self.batch_size,
+ self.red_pos,
+ self.blue_pos,
+ self.red_coin_pos,
+ self.blue_coin_pos,
+ self.grid_size,
+ self.red_coin,
+ self.punishment_helped,
+ )
+
+ if self.output_additional_info:
+ self._accumulate_info(
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ picked_red_coop,
+ picked_blue_coop,
+ )
+
+ obs = self._apply_optional_invariance_to_the_player_trained(
+ observation
+ )
+ obs, rewards = self._optional_unvectorize(obs, rewards)
+
+ return self._to_RLLib_API(obs, rewards)
+
+ @override(vectorized_coin_game.VectorizedCoinGame)
+ def _save_env(self):
+ env_save_state = {
+ "red_pos": self.red_pos,
+ "blue_pos": self.blue_pos,
+ "red_coin_pos": self.red_coin_pos,
+ "blue_coin_pos": self.blue_coin_pos,
+ "red_coin": self.red_coin,
+ "grid_size": self.grid_size,
+ "batch_size": self.batch_size,
+ "step_count_in_current_episode": self.step_count_in_current_episode,
+ "max_steps": self.max_steps,
+ "red_pick": self.red_pick,
+ "red_pick_own": self.red_pick_own,
+ "blue_pick": self.blue_pick,
+ "blue_pick_own": self.blue_pick_own,
+ }
+ return copy.deepcopy(env_save_state)
+
+ @override(vectorized_coin_game.VectorizedCoinGame)
+ def _init_info(self):
+ super()._init_info()
+ self.picked_red_coop = []
+ self.picked_blue_coop = []
+
+ @override(vectorized_coin_game.VectorizedCoinGame)
+ def _reset_info(self):
+ super()._reset_info()
+ self.picked_red_coop.clear()
+ self.picked_blue_coop.clear()
+
+ @override(vectorized_coin_game.VectorizedCoinGame)
+ def _accumulate_info(
+ self,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ picked_red_coop,
+ picked_blue_coop,
+ ):
+ super()._accumulate_info(
+ red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue
+ )
+ self.picked_red_coop.append(picked_red_coop)
+ self.picked_blue_coop.append(picked_blue_coop)
+
+ @override(vectorized_coin_game.VectorizedCoinGame)
+ def _get_episode_info(self):
+ """
+ Return the following information:
+ pick_speed is the fraction of steps during which the player picked a
+ coin.
+ pick_own_color is the fraction of coins picked by the player which have
+ the same color as the player.
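+ coop_speed is the fraction of steps during which the player picked a
+ cooperative coin.
+ coop_fraction is the total number of cooperative pick-ups (by either
+ player) divided by the number of coins picked by the player.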
+ """
+ n_steps_played = len(self.red_pick) * self.batch_size
+ player_red_info, player_blue_info = super()._get_episode_info(
+ n_steps_played=n_steps_played
+ )
+ n_coop = sum(self.picked_blue_coop) + sum(self.picked_red_coop)
+
+ if len(self.red_pick) > 0:
+ red_pick = sum(self.red_pick)
+ player_red_info["red_coop_speed"] = (
+ sum(self.picked_red_coop) / n_steps_played
+ )
+ if red_pick > 0:
+ player_red_info["red_coop_fraction"] = n_coop / red_pick
+
+ if len(self.blue_pick) > 0:
+ blue_pick = sum(self.blue_pick)
+ player_blue_info["blue_coop_speed"] = (
+ sum(self.picked_blue_coop) / n_steps_played
+ )
+ if blue_pick > 0:
+ player_blue_info["blue_coop_fraction"] = n_coop / blue_pick
+
+ return player_red_info, player_blue_info
+
+
+@jit(nopython=True)
+def vectorized_step_wt_numba_optimization(
+ actions,
+ batch_size,
+ red_pos,
+ blue_pos,
+ red_coin_pos,
+ blue_coin_pos,
+ grid_size: int,
+ red_coin,
+ punishment_helped,
+):
+ red_pos, blue_pos = move_players(
+ batch_size, actions, red_pos, blue_pos, grid_size
+ )
+
+ (
+ reward,
+ generate,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ picked_red_coop,
+ picked_blue_coop,
+ ) = compute_reward(
+ batch_size,
+ red_pos,
+ blue_pos,
+ red_coin_pos,
+ blue_coin_pos,
+ red_coin,
+ punishment_helped,
+ )
+
+ (
+ red_coin_pos,
+ blue_coin_pos,
+ red_coin,
+ ) = generate_coin_wt_numba_optimization(
+ batch_size,
+ generate,
+ red_coin_pos,
+ blue_coin_pos,
+ red_pos,
+ blue_pos,
+ grid_size,
+ red_coin,
+ )
+
+ obs = generate_observations_wt_numba_optimization(
+ batch_size,
+ red_pos,
+ blue_pos,
+ red_coin_pos,
+ blue_coin_pos,
+ grid_size,
+ red_coin,
+ )
+
+ return (
+ red_pos,
+ blue_pos,
+ reward,
+ red_coin_pos,
+ blue_coin_pos,
+ obs,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ picked_red_coop,
+ picked_blue_coop,
+ red_coin,
+ )
+
+
+@jit(nopython=True)
+def compute_reward(
+ batch_size,
+ red_pos,
+ blue_pos,
+ red_coin_pos,
+ blue_coin_pos,
+ red_coin,
+ punishment_helped,
+):
+ reward_red = np.zeros(batch_size)
+ reward_blue = np.zeros(batch_size)
+ generate = np.zeros(batch_size, dtype=np.int_)
+ red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0
+ picked_red_coop, picked_blue_coop = 0, 0
+
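+ # When red_coin[i] == 1, the red coin is the cooperative coin (it pays
+ # out only when both players stand on it) and the blue coin is the
+ # selfish coin; when red_coin[i] == 0, the roles are swapped.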
+ for i in prange(batch_size):
+ if _same_pos(red_pos[i], red_coin_pos[i]):
+ if red_coin[i] and _same_pos(blue_pos[i], red_coin_pos[i]):
+ # Red coin is a coop coin
+ generate[i] = True
+ reward_red[i] += 2.0
+ red_pick_any += 1
+ red_pick_red += 1
+ blue_pick_any += 1
+ picked_red_coop += 1
+ elif not red_coin[i]:
+ # Red coin is a selfish coin
+ generate[i] = True
+ reward_red[i] += 1.0
+ red_pick_any += 1
+ red_pick_red += 1
+ if punishment_helped and _same_pos(
+ blue_pos[i], red_coin_pos[i]
+ ):
+ reward_red[i] -= 0.75
+ elif _same_pos(blue_pos[i], blue_coin_pos[i]):
+ if not red_coin[i] and _same_pos(red_pos[i], blue_coin_pos[i]):
+ # Blue coin is a coop coin
+ generate[i] = True
+ reward_blue[i] += 3.0
+ red_pick_any += 1
+ blue_pick_any += 1
+ blue_pick_blue += 1
+ picked_blue_coop += 1
+ elif red_coin[i]:
+ # Blue coin is a selfish coin
+ generate[i] = True
+ reward_blue[i] += 1.0
+ blue_pick_any += 1
+ blue_pick_blue += 1
+ if punishment_helped and _same_pos(
+ red_pos[i], blue_coin_pos[i]
+ ):
+ reward_blue[i] -= 0.75
+
+ reward = [reward_red, reward_blue]
+
+ return (
+ reward,
+ generate,
+ red_pick_any,
+ red_pick_red,
+ blue_pick_any,
+ blue_pick_blue,
+ picked_red_coop,
+ picked_blue_coop,
+ )
+
+
+@jit(nopython=True)
+def generate_observations_wt_numba_optimization(
+ batch_size,
+ red_pos,
+ blue_pos,
+ red_coin_pos,
+ blue_coin_pos,
+ grid_size,
+ red_coin,
+):
+ obs = np.zeros((batch_size, grid_size, grid_size, 6))
+ for i in prange(batch_size):
+ obs[i, red_pos[i][0], red_pos[i][1], 0] = 1
+ obs[i, blue_pos[i][0], blue_pos[i][1], 1] = 1
+ if red_coin[i]:
+ # Feature channel 4 holds the red cooperative coin
+ obs[i, red_coin_pos[i][0], red_coin_pos[i][1], 4] = 1
+ # Feature channel 3 holds the blue selfish coin
+ obs[i, blue_coin_pos[i][0], blue_coin_pos[i][1], 3] = 1
+ else:
+ # Feature channel 2 holds the red selfish coin
+ obs[i, red_coin_pos[i][0], red_coin_pos[i][1], 2] = 1
+ # Feature channel 5 holds the blue cooperative coin
+ obs[i, blue_coin_pos[i][0], blue_coin_pos[i][1], 5] = 1
+ return obs
+
+
+@jit(nopython=True)
+def generate_coin_wt_numba_optimization(
+ batch_size,
+ generate,
+ red_coin_pos,
+ blue_coin_pos,
+ red_pos,
+ blue_pos,
+ grid_size,
+ red_coin,
+):
+ for i in prange(batch_size):
+ if generate[i]:
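+ # A pick-up flips which color plays the cooperative role and
+ # re-places both coins away from the players.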
+ red_coin[i] = 1 - red_coin[i]
+ red_coin_pos[i] = _place_coin(
+ red_pos[i], blue_pos[i], grid_size, blue_coin_pos[i]
+ )
+ blue_coin_pos[i] = _place_coin(
+ red_pos[i], blue_pos[i], grid_size, red_coin_pos[i]
+ )
+ return red_coin_pos, blue_coin_pos, red_coin
diff --git a/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb b/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb
index bb7bda6..c2a13e5 100644
--- a/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb
+++ b/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb
@@ -75,8 +75,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:58:39.129878Z",
- "start_time": "2021-04-14T09:58:39.126721Z"
+ "end_time": "2021-04-15T12:33:11.973283Z",
+ "start_time": "2021-04-15T12:33:11.970475Z"
},
"id": "DsrzADRv4Itm"
},
@@ -172,8 +172,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:58:41.814223Z",
- "start_time": "2021-04-14T09:58:39.133920Z"
+ "end_time": "2021-04-15T12:33:14.647926Z",
+ "start_time": "2021-04-15T12:33:11.976742Z"
},
"id": "BdtPEwKt-b47"
},
@@ -191,6 +191,7 @@
"\n",
"from marltoolbox.envs.matrix_sequential_social_dilemma import IteratedPrisonersDilemma\n",
"from marltoolbox.utils.miscellaneous import check_learning_achieved\n",
+ "from marltoolbox.utils import log\n",
"\n",
"from IPython.core.display import display, HTML\n",
"display(HTML(\"\"))"
@@ -224,8 +225,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:58:41.821285Z",
- "start_time": "2021-04-14T09:58:41.816146Z"
+ "end_time": "2021-04-15T12:33:14.653543Z",
+ "start_time": "2021-04-15T12:33:14.650069Z"
},
"id": "MHNIE5wp-nV8"
},
@@ -256,8 +257,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:58:41.907870Z",
- "start_time": "2021-04-14T09:58:41.824356Z"
+ "end_time": "2021-04-15T12:33:14.742628Z",
+ "start_time": "2021-04-15T12:33:14.655685Z"
},
"id": "nEC3Fbp9trEZ"
},
@@ -284,8 +285,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:58:42.005535Z",
- "start_time": "2021-04-14T09:58:41.910259Z"
+ "end_time": "2021-04-15T12:33:14.857873Z",
+ "start_time": "2021-04-15T12:33:14.745371Z"
},
"id": "6ZyWFTfy-KSd"
},
@@ -329,8 +330,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:58:42.124090Z",
- "start_time": "2021-04-14T09:58:42.007028Z"
+ "end_time": "2021-04-15T12:33:14.926104Z",
+ "start_time": "2021-04-15T12:33:14.859461Z"
},
"id": "DrBiGbnSBFns"
},
@@ -389,8 +390,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:58:42.211967Z",
- "start_time": "2021-04-14T09:58:42.127092Z"
+ "end_time": "2021-04-15T12:33:15.020084Z",
+ "start_time": "2021-04-15T12:33:14.928681Z"
},
"id": "tB0ig1cuCkl7"
},
@@ -480,8 +481,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:58:42.349958Z",
- "start_time": "2021-04-14T09:58:42.216488Z"
+ "end_time": "2021-04-15T12:33:15.114115Z",
+ "start_time": "2021-04-15T12:33:15.022186Z"
},
"id": "s4G0pHVJCpGU"
},
@@ -586,8 +587,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:58:53.956824Z",
- "start_time": "2021-04-14T09:58:42.352177Z"
+ "end_time": "2021-04-15T12:33:26.563552Z",
+ "start_time": "2021-04-15T12:33:15.119380Z"
},
"id": "ST5qb68DCMI7",
"scrolled": true
@@ -602,6 +603,7 @@
"\n",
"# stop the training after N Trainable.step (here this is equivalent to N episodes and N updates)\n",
"stop_config = {\"training_iteration\": 200} \n",
+ "exp_name, _ = log.log_in_current_day_dir(\"PG_IPD_notebook_exp\")\n",
"\n",
"# Restart Ray defensively in case the ray connection is lost.\n",
"ray.shutdown() \n",
@@ -611,7 +613,7 @@
" Trainable,\n",
" stop=stop_config,\n",
" config=tune_config,\n",
- " name=\"PG_IPD\",\n",
+ " name=exp_name,\n",
" )\n",
"\n",
"ray.shutdown()\n",
@@ -651,8 +653,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:59:06.311720Z",
- "start_time": "2021-04-14T09:58:53.963206Z"
+ "end_time": "2021-04-15T12:33:39.209663Z",
+ "start_time": "2021-04-15T12:33:26.565556Z"
},
"id": "vMbfhunq_ben"
},
@@ -669,6 +671,7 @@
"# but here varying \"training_iteration\" is interesting.\n",
"stop_config = {\"training_iteration\": tune.grid_search([2, 4, 8, 16, 32, 64, 128])} \n",
"#####\n",
+ "exp_name, _ = log.log_in_current_day_dir(\"PG_IPD_notebook_exp\")\n",
"\n",
"ray.shutdown() \n",
"ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n",
@@ -676,7 +679,7 @@
" Trainable,\n",
" stop=stop_config,\n",
" config=tune_config,\n",
- " name=\"PG_IPD\",\n",
+ " name=exp_name,\n",
" )\n",
"ray.shutdown()\n",
"\n",
@@ -718,8 +721,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:00:01.360129Z",
- "start_time": "2021-04-14T10:00:01.355038Z"
+ "end_time": "2021-04-15T12:33:39.226528Z",
+ "start_time": "2021-04-15T12:33:39.211410Z"
},
"id": "9uHr3nuI-Lap"
},
@@ -780,8 +783,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:02:10.235678Z",
- "start_time": "2021-04-14T10:02:10.224537Z"
+ "end_time": "2021-04-15T12:33:39.355656Z",
+ "start_time": "2021-04-15T12:33:39.228620Z"
},
"id": "mbxWsThD-vEt"
},
@@ -855,8 +858,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:59:06.631534Z",
- "start_time": "2021-04-14T09:59:06.473496Z"
+ "end_time": "2021-04-15T12:33:39.486196Z",
+ "start_time": "2021-04-15T12:33:39.360619Z"
},
"id": "4-VPfgGP9-XH"
},
@@ -911,8 +914,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:59:06.714373Z",
- "start_time": "2021-04-14T09:59:06.632964Z"
+ "end_time": "2021-04-15T12:33:39.578807Z",
+ "start_time": "2021-04-15T12:33:39.487619Z"
},
"id": "gej3suhHAo0S"
},
@@ -984,8 +987,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:59:06.824496Z",
- "start_time": "2021-04-14T09:59:06.716292Z"
+ "end_time": "2021-04-15T12:33:39.713752Z",
+ "start_time": "2021-04-15T12:33:39.584035Z"
},
"id": "aJIHVPLjbfiC"
},
@@ -1047,8 +1050,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:59:06.916296Z",
- "start_time": "2021-04-14T09:59:06.826239Z"
+ "end_time": "2021-04-15T12:33:39.901063Z",
+ "start_time": "2021-04-15T12:33:39.715122Z"
},
"id": "qqlYU7ZiALxl"
},
@@ -1091,6 +1094,8 @@
" \"framework\": \"torch\",\n",
" \"batch_mode\": \"complete_episodes\",\n",
" # LTFT supports only 1 worker only otherwise it would be mixing several opponents trajectories\n",
+ " # Number of rollout worker actors to create for parallel sampling. Setting\n",
+ " # this to 0 will force rollouts to be done in the trainer actor.\n",
" \"num_workers\": 0,\n",
" # LTFT supports only 1 env per worker only otherwise several episodes would be played at the same time\n",
" \"num_envs_per_worker\": 1,\n",
@@ -1116,10 +1121,11 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:04:28.482320Z",
- "start_time": "2021-04-14T10:02:15.269072Z"
+ "end_time": "2021-04-15T12:35:51.624256Z",
+ "start_time": "2021-04-15T12:33:39.902523Z"
},
- "id": "bfEaO43v3UBB"
+ "id": "bfEaO43v3UBB",
+ "scrolled": true
},
"outputs": [],
"source": [
@@ -1139,14 +1145,14 @@
" \"debug\": False,\n",
"}\n",
"\n",
- "\n",
+ "exp_name, _ = log.log_in_current_day_dir(\"LTFT_notebook_exp\")\n",
"rllib_config = get_rllib_config(ltft_hparameters)\n",
"stop_config = get_stop_config(ltft_hparameters)\n",
"ray.shutdown()\n",
"ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n",
"tune_analysis_self_play = ray.tune.run(ltft.LTFTTrainer, config=rllib_config,\n",
" checkpoint_freq=0, stop=stop_config, \n",
- " checkpoint_at_end=False, name=\"LTFT_exp\")\n",
+ " checkpoint_at_end=False, name=exp_name)\n",
"ray.shutdown()\n",
"\n",
"check_learning_achieved(tune_results=tune_analysis_self_play, \n",
@@ -1192,8 +1198,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:04:28.491013Z",
- "start_time": "2021-04-14T10:04:28.484650Z"
+ "end_time": "2021-04-15T12:35:51.630130Z",
+ "start_time": "2021-04-15T12:35:51.626154Z"
},
"id": "yCZzbx3Ld0vz"
},
@@ -1224,8 +1230,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:05:15.895877Z",
- "start_time": "2021-04-14T10:04:28.493528Z"
+ "end_time": "2021-04-15T12:36:53.910972Z",
+ "start_time": "2021-04-15T12:35:51.631535Z"
},
"id": "zETjK-Y5ZKLA"
},
@@ -1249,14 +1255,14 @@
" \"debug\": False,\n",
"}\n",
"\n",
- "\n",
+ "exp_name, _ = log.log_in_current_day_dir(\"LTFT_notebook_exp\")\n",
"rllib_config = get_rllib_config(ltft_hparameters)\n",
"stop_config = get_stop_config(ltft_hparameters)\n",
"ray.shutdown()\n",
"ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n",
"tune_analysis_self_play = ray.tune.run(ltft.LTFTTrainer, config=rllib_config,\n",
" checkpoint_freq=0, stop=stop_config, \n",
- " checkpoint_at_end=False, name=\"LTFT_exp\")\n",
+ " checkpoint_at_end=False, name=exp_name)\n",
"ray.shutdown()\n",
"\n",
"check_learning_achieved(tune_results=tune_analysis_self_play,\n",
@@ -1299,8 +1305,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:59:10.306079Z",
- "start_time": "2021-04-14T09:58:39.178Z"
+ "end_time": "2021-04-15T12:36:53.923288Z",
+ "start_time": "2021-04-15T12:36:53.915782Z"
},
"id": "u00VE8oB4Itx"
},
@@ -1314,8 +1320,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T09:59:10.306786Z",
- "start_time": "2021-04-14T09:58:39.180Z"
+ "end_time": "2021-04-15T12:36:54.052499Z",
+ "start_time": "2021-04-15T12:36:53.925743Z"
},
"id": "NpeJySSk4Itx",
"scrolled": true
diff --git a/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb b/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb
index 5989126..8811697 100644
--- a/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb
+++ b/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb
@@ -32,8 +32,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:06:53.433046Z",
- "start_time": "2021-04-14T10:06:53.428041Z"
+ "end_time": "2021-04-15T12:55:08.510232Z",
+ "start_time": "2021-04-15T12:55:08.501235Z"
},
"id": "KMXC7aH5CD8p"
},
@@ -62,8 +62,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:06:53.517255Z",
- "start_time": "2021-04-14T10:06:53.435182Z"
+ "end_time": "2021-04-15T12:55:08.594695Z",
+ "start_time": "2021-04-15T12:55:08.512706Z"
},
"id": "DsrzADRv4Itm"
},
@@ -141,8 +141,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:06:56.606917Z",
- "start_time": "2021-04-14T10:06:53.518906Z"
+ "end_time": "2021-04-15T12:55:11.544021Z",
+ "start_time": "2021-04-15T12:55:08.597116Z"
},
"id": "BdtPEwKt-b47"
},
@@ -188,8 +188,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:06:56.614247Z",
- "start_time": "2021-04-14T10:06:56.609516Z"
+ "end_time": "2021-04-15T12:55:11.550532Z",
+ "start_time": "2021-04-15T12:55:11.546175Z"
},
"id": "9fZtrKHovNvV"
},
@@ -216,8 +216,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:06:56.722108Z",
- "start_time": "2021-04-14T10:06:56.616211Z"
+ "end_time": "2021-04-15T13:08:41.276000Z",
+ "start_time": "2021-04-15T13:08:41.268429Z"
},
"id": "MHNIE5wp-nV8"
},
@@ -229,10 +229,10 @@
"\n",
" # This modification to the policy will allow us to load each policies from different checkpoints \n",
" # This will be used during evaluation.\n",
- " def merged_after_init(*args, **kwargs):\n",
+ " def merged_before_loss_init(*args, **kwargs):\n",
" setup_mixins(*args, **kwargs)\n",
- " restore.after_init_load_policy_checkpoint(*args, **kwargs)\n",
- " MyPPOPolicy = PPOTorchPolicy.with_updates(after_init=merged_after_init)\n",
+ " restore.before_loss_init_load_policy_checkpoint(*args, **kwargs)\n",
+ " MyPPOPolicy = PPOTorchPolicy.with_updates(before_loss_init=merged_before_loss_init)\n",
"\n",
" stop_config = {\n",
" \"episodes_total\": hp[\"episodes_total\"],\n",
@@ -311,18 +311,19 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:07:47.492122Z",
- "start_time": "2021-04-14T10:06:56.726437Z"
+ "end_time": "2021-04-15T12:55:59.020889Z",
+ "start_time": "2021-04-15T12:55:11.652388Z"
},
"id": "XijF3Q7O0FxF"
},
"outputs": [],
"source": [
+ "exp_name, _ = log.log_in_current_day_dir(\"PPO_BoS_notebook_exp\")\n",
"hyperparameters = {\n",
" \"steps_per_epi\": 20,\n",
" \"train_n_replicates\": 8,\n",
" \"episodes_total\": 200,\n",
- " \"exp_name\": \"PPO_BoS\",\n",
+ " \"exp_name\": exp_name,\n",
" \"base_lr\": 5e-1,\n",
"}\n",
"\n",
@@ -350,8 +351,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:07:47.498013Z",
- "start_time": "2021-04-14T10:07:47.494191Z"
+ "end_time": "2021-04-15T12:55:59.045103Z",
+ "start_time": "2021-04-15T12:55:59.024866Z"
},
"id": "e4qziRvL0o_E"
},
@@ -375,8 +376,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:08:18.230859Z",
- "start_time": "2021-04-14T10:07:47.499394Z"
+ "end_time": "2021-04-15T13:11:25.364166Z",
+ "start_time": "2021-04-15T13:10:45.951397Z"
},
"id": "6ZyWFTfy-KSd"
},
@@ -455,8 +456,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:10:12.067609Z",
- "start_time": "2021-04-14T10:10:12.062918Z"
+ "end_time": "2021-04-15T12:56:29.454882Z",
+ "start_time": "2021-04-15T12:56:28.770969Z"
},
"id": "wmvRIbPqUAwy"
},
@@ -503,8 +504,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:08:19.676226Z",
- "start_time": "2021-04-14T10:08:19.671968Z"
+ "end_time": "2021-04-15T12:56:29.461390Z",
+ "start_time": "2021-04-15T12:56:29.456601Z"
},
"id": "bVqSBL72wHrd"
},
@@ -531,8 +532,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:08:19.788760Z",
- "start_time": "2021-04-14T10:08:19.678381Z"
+ "end_time": "2021-04-15T12:56:29.561534Z",
+ "start_time": "2021-04-15T12:56:29.463140Z"
},
"id": "vMbfhunq_ben"
},
@@ -540,11 +541,12 @@
"source": [
"def train_lvl0_agents(lvl0_hparameters):\n",
"\n",
+ " exp_name, _ = log.log_in_current_day_dir(\"Lvl0_LOLAExact_notebook_exp\")\n",
" tune_config = get_tune_config(lvl0_hparameters)\n",
" stop_config = get_stop_config(lvl0_hparameters)\n",
" ray.shutdown()\n",
" ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n",
- " tune_analysis_lvl0 = tune.run(LOLAExact, name=\"Lvl0_LOLAExact\", config=tune_config,\n",
+ " tune_analysis_lvl0 = tune.run(LOLAExact, name=exp_name, config=tune_config,\n",
" checkpoint_at_end=True, stop=stop_config, \n",
" metric=lvl0_hparameters[\"metric\"], mode=\"max\")\n",
" ray.shutdown()\n",
@@ -583,8 +585,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:09:38.553623Z",
- "start_time": "2021-04-14T10:08:19.792698Z"
+ "end_time": "2021-04-15T12:57:43.481511Z",
+ "start_time": "2021-04-15T12:56:29.563639Z"
},
"id": "Im5nFVs7YUtU"
},
@@ -635,8 +637,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:11:56.338029Z",
- "start_time": "2021-04-14T10:11:56.328179Z"
+ "end_time": "2021-04-15T12:57:43.492585Z",
+ "start_time": "2021-04-15T12:57:43.483312Z"
},
"id": "P2loJhuzVho6"
},
@@ -644,6 +646,7 @@
"source": [
"def train_lvl1_agents(hp_lvl1, tune_analysis_lvl0):\n",
"\n",
+ " exp_name, _ = log.log_in_current_day_dir(\"Lvl1_PG_notebook_exp\")\n",
" rllib_config_lvl1, trainable_class, env_config = get_rllib_config(hp_lvl1)\n",
" stop_config = get_stop_config(hp_lvl1)\n",
" \n",
@@ -659,7 +662,7 @@
" stop=stop_config,\n",
" checkpoint_at_end=True,\n",
" metric=\"episode_reward_mean\", mode=\"max\",\n",
- " name=\"Lvl1_PG\")\n",
+ " name=exp_name)\n",
" ray.shutdown()\n",
" return tune_analysis_lvl1\n",
"\n",
@@ -746,8 +749,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:12:30.455824Z",
- "start_time": "2021-04-14T10:11:58.126061Z"
+ "end_time": "2021-04-15T12:58:19.055673Z",
+ "start_time": "2021-04-15T12:57:43.493968Z"
},
"id": "b2kmotZMYZBX",
"scrolled": true
@@ -785,8 +788,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:12:34.152484Z",
- "start_time": "2021-04-14T10:12:33.899223Z"
+ "end_time": "2021-04-15T12:58:19.100252Z",
+ "start_time": "2021-04-15T12:58:19.056994Z"
}
},
"outputs": [],
@@ -799,8 +802,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:12:36.039089Z",
- "start_time": "2021-04-14T10:12:35.975270Z"
+ "end_time": "2021-04-15T12:58:19.256941Z",
+ "start_time": "2021-04-15T12:58:19.101888Z"
},
"id": "M1lSyEf1D08q"
},
@@ -828,8 +831,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:13:04.766311Z",
- "start_time": "2021-04-14T10:13:04.716165Z"
+ "end_time": "2021-04-15T12:58:19.317574Z",
+ "start_time": "2021-04-15T12:58:19.258240Z"
},
"id": "v67qWvcDD08q"
},
@@ -871,8 +874,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:09:39.229687Z",
- "start_time": "2021-04-14T10:06:53.463Z"
+ "end_time": "2021-04-15T12:58:19.401613Z",
+ "start_time": "2021-04-15T12:58:19.319051Z"
},
"id": "u00VE8oB4Itx"
},
@@ -886,8 +889,8 @@
"execution_count": null,
"metadata": {
"ExecuteTime": {
- "end_time": "2021-04-14T10:09:39.230757Z",
- "start_time": "2021-04-14T10:06:53.465Z"
+ "end_time": "2021-04-15T12:58:19.501562Z",
+ "start_time": "2021-04-15T12:58:19.403275Z"
},
"id": "NpeJySSk4Itx",
"scrolled": true
diff --git a/marltoolbox/examples/rllib_api/dqn_coin_game.py b/marltoolbox/examples/rllib_api/dqn_coin_game.py
index ba32255..140139e 100644
--- a/marltoolbox/examples/rllib_api/dqn_coin_game.py
+++ b/marltoolbox/examples/rllib_api/dqn_coin_game.py
@@ -18,38 +18,36 @@ def main(debug):
seeds = miscellaneous.get_random_seeds(train_n_replicates)
exp_name, _ = log.log_in_current_day_dir("DQN_CG")
- hparams = _get_hyperparameters(seeds, debug, exp_name)
+ hparams = get_hyperparameters(seeds, debug, exp_name)
- rllib_config, stop_config = _get_rllib_configs(hparams)
+ rllib_config, stop_config = get_rllib_configs(hparams)
- tune_analysis = _train_dqn_and_plot_logs(
- hparams, rllib_config, stop_config)
+ tune_analysis = train_dqn_and_plot_logs(hparams, rllib_config, stop_config)
return tune_analysis
-def _get_hyperparameters(seeds, debug, exp_name):
+def get_hyperparameters(seeds, debug, exp_name):
+ """Get hyperparameters for the Coin Game env and DQN agents"""
+
hparams = {
"seeds": seeds,
"debug": debug,
"exp_name": exp_name,
"n_steps_per_epi": 100,
- "n_epi": 2000,
- "buf_frac": 0.125,
- "last_exploration_temp_value": 0.1,
- "bs_epi_mul": 1,
-
- "plot_keys":
- coin_game.PLOT_KEYS +
- aggregate_and_plot_tensorboard_data.PLOT_KEYS,
- "plot_assemblage_tags":
- coin_game.PLOT_ASSEMBLAGE_TAGS +
- aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS,
+ "n_epi": 4000,
+ "buf_frac": 0.5,
+ "last_exploration_temp_value": 0.003,
+ "bs_epi_mul": 4,
+ "plot_keys": coin_game.PLOT_KEYS
+ + aggregate_and_plot_tensorboard_data.PLOT_KEYS,
+ "plot_assemblage_tags": coin_game.PLOT_ASSEMBLAGE_TAGS
+ + aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS,
}
return hparams
-def _get_rllib_configs(hp, env_class=None):
+def get_rllib_configs(hp, env_class=None):
stop_config = {
"episodes_total": 2 if hp["debug"] else hp["n_epi"],
}
@@ -59,58 +57,70 @@ def _get_rllib_configs(hp, env_class=None):
"max_steps": hp["n_steps_per_epi"],
"grid_size": 3,
"get_additional_info": True,
+ "buf_frac": hp["buf_frac"],
+ "bs_epi_mul": hp["bs_epi_mul"],
}
env_class = coin_game.CoinGame if env_class is None else env_class
rllib_config = {
"env": env_class,
"env_config": env_config,
-
"multiagent": {
"policies": {
env_config["players_ids"][0]: (
augmented_dqn.MyDQNTorchPolicy,
env_class(env_config).OBSERVATION_SPACE,
env_class.ACTION_SPACE,
- {}),
+ {},
+ ),
env_config["players_ids"][1]: (
augmented_dqn.MyDQNTorchPolicy,
env_class(env_config).OBSERVATION_SPACE,
env_class.ACTION_SPACE,
- {}),
+ {},
+ ),
},
"policy_mapping_fn": lambda agent_id: agent_id,
},
-
# === DQN Models ===
-
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": tune.sample_from(
- lambda spec: int(spec.config["env_config"]["max_steps"] * 30)),
+ lambda spec: int(spec.config["env_config"]["max_steps"] * 30)
+ ),
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": tune.sample_from(
- lambda spec: int(spec.config["env_config"]["max_steps"] *
- spec.stop["episodes_total"] * hp["buf_frac"])),
+ lambda spec: int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ * spec.config["env_config"]["buf_frac"]
+ )
+ ),
# Whether to use dueling dqn
- "dueling": False,
+ "dueling": True,
# Whether to use double dqn
"double_q": True,
# If True prioritized replay buffer will be used.
"prioritized_replay": False,
-
"rollout_fragment_length": tune.sample_from(
- lambda spec: spec.config["env_config"]["max_steps"]),
- "training_intensity": 10,
+ lambda spec: spec.config["env_config"]["max_steps"]
+ ),
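+ # Scale the training intensity with the total number of parallel
+ # environments sampled.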
+ "training_intensity": tune.sample_from(
+ lambda spec: spec.config["num_envs_per_worker"]
+ * max(spec.config["num_workers"], 1)
+ * 40
+ ),
# Size of a batch sampled from replay buffer for training. Note that
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
"train_batch_size": tune.sample_from(
- lambda spec: int(spec.config["env_config"]["max_steps"] *
- hp["bs_epi_mul"])),
+ lambda spec: int(
+ spec.config["env_config"]["max_steps"]
+ * spec.config["env_config"]["bs_epi_mul"]
+ )
+ ),
"batch_mode": "complete_episodes",
-
# === Exploration Settings ===
# Default exploration behavior, iff `explore`=None is passed into
# compute_action(s).
@@ -130,61 +140,85 @@ def _get_rllib_configs(hp, env_class=None):
"temperature_schedule": tune.sample_from(
lambda spec: PiecewiseSchedule(
endpoints=[
- (0,
- 2.0),
- (int(spec.config["env_config"]["max_steps"] *
- spec.stop["episodes_total"] * 0.20),
- 0.5),
- (int(spec.config["env_config"]["max_steps"] *
- spec.stop["episodes_total"] * 0.60),
- hp["last_exploration_temp_value"])],
+ (0, 1.0),
+ (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ * 0.20
+ ),
+ 0.45,
+ ),
+ (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ * 0.9
+ ),
+ hp["last_exploration_temp_value"],
+ ),
+ ],
outside_value=hp["last_exploration_temp_value"],
- framework="torch")),
+ framework="torch",
+ )
+ ),
},
-
# Size of batches collected from each worker.
"model": {
"dim": env_config["grid_size"],
# [Channel, [Kernel, Kernel], Stride]]
- "conv_filters": [[16, [3, 3], 1], [32, [3, 3], 1]]
+ "conv_filters": [[64, [3, 3], 1], [64, [3, 3], 1]],
+ "fcnet_hiddens": [64, 64],
},
+ "hiddens": [32],
"gamma": 0.96,
- "optimizer": {"sgd_momentum": 0.9, },
+ "optimizer": {
+ "sgd_momentum": 0.9,
+ },
"lr": 0.1,
"lr_schedule": tune.sample_from(
lambda spec: [
(0, 0.0),
- (int(spec.config["env_config"]["max_steps"] *
- spec.stop["episodes_total"] * 0.05),
- spec.config.lr),
- (int(spec.config["env_config"]["max_steps"] *
- spec.stop["episodes_total"]),
- spec.config.lr / 1e9)
+ (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ * 0.05
+ ),
+ spec.config.lr,
+ ),
+ (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ ),
+ spec.config.lr / 1e9,
+ ),
]
),
-
"seed": tune.grid_search(hp["seeds"]),
"callbacks": log.get_logging_callbacks_class(),
"framework": "torch",
-
"logger_config": {
"wandb": {
"project": "DQN_CG",
"group": hp["exp_name"],
- "api_key_file":
- os.path.join(os.path.dirname(__file__),
- "../../../api_key_wandb"),
- "log_config": True
+ "api_key_file": os.path.join(
+ os.path.dirname(__file__), "../../../api_key_wandb"
+ ),
+ "log_config": True,
},
},
-
+ "num_envs_per_worker": 16,
+ "num_workers": 0,
+ # "log_level": "INFO",
}
return rllib_config, stop_config
-def _train_dqn_and_plot_logs(hp, rllib_config, stop_config):
- ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=hp["debug"])
+def train_dqn_and_plot_logs(hp, rllib_config, stop_config):
+ ray.init(num_cpus=os.cpu_count(), local_mode=hp["debug"])
tune_analysis = tune.run(
DQNTrainer,
config=rllib_config,
diff --git a/marltoolbox/examples/rllib_api/dqn_wt_welfare.py b/marltoolbox/examples/rllib_api/dqn_wt_welfare.py
index 10a7083..601d20c 100644
--- a/marltoolbox/examples/rllib_api/dqn_wt_welfare.py
+++ b/marltoolbox/examples/rllib_api/dqn_wt_welfare.py
@@ -10,29 +10,35 @@ def main(debug, welfare=postprocessing.WELFARE_UTILITARIAN):
seeds = miscellaneous.get_random_seeds(train_n_replicates)
exp_name, _ = log.log_in_current_day_dir("DQN_welfare_CG")
- hparams = dqn_coin_game._get_hyperparameters(seeds, debug, exp_name)
- rllib_config, stop_config = dqn_coin_game._get_rllib_configs(hparams)
- rllib_config = _modify_policy_to_use_welfare(rllib_config, welfare)
+ hparams = dqn_coin_game.get_hyperparameters(seeds, debug, exp_name)
+ rllib_config, stop_config = dqn_coin_game.get_rllib_configs(hparams)
+ rllib_config = modify_dqn_rllib_config_to_use_welfare(
+ rllib_config, welfare
+ )
- tune_analysis = dqn_coin_game._train_dqn_and_plot_logs(
- hparams, rllib_config, stop_config)
+ tune_analysis = dqn_coin_game.train_dqn_and_plot_logs(
+ hparams, rllib_config, stop_config
+ )
return tune_analysis
-def _modify_policy_to_use_welfare(rllib_config, welfare):
- MyCoopDQNTorchPolicy = augmented_dqn.MyDQNTorchPolicy.with_updates(
- postprocess_fn=miscellaneous.merge_policy_postprocessing_fn(
- postprocessing.welfares_postprocessing_fn(),
- postprocess_nstep_and_prio,
- )
+def modify_dqn_rllib_config_to_use_welfare(rllib_config, welfare):
+ DQNTorchPolicyWtWelfare = _get_policy_class_wt_welfare_preprocessing()
+ rllib_config = modify_rllib_config_to_use_welfare(
+ rllib_config, welfare, policy_class_wt_welfare=DQNTorchPolicyWtWelfare
)
+ return rllib_config
+
+def modify_rllib_config_to_use_welfare(
+ rllib_config, welfare, policy_class_wt_welfare, overwrite_reward=True
+):
policies = rllib_config["multiagent"]["policies"]
new_policies = {}
for policies_id, policy_tuple in policies.items():
new_policies[policies_id] = list(policy_tuple)
- new_policies[policies_id][0] = MyCoopDQNTorchPolicy
+ new_policies[policies_id][0] = policy_class_wt_welfare
if welfare == postprocessing.WELFARE_UTILITARIAN:
new_policies[policies_id][3].update(
{postprocessing.ADD_UTILITARIAN_WELFARE: True}
@@ -40,7 +46,7 @@ def _modify_policy_to_use_welfare(rllib_config, welfare):
elif welfare == postprocessing.WELFARE_INEQUITY_AVERSION:
add_ia_w = True
ia_alpha = 0.0
- ia_beta = 0.5
+ ia_beta = 0.5 / 2
ia_gamma = 0.96
ia_lambda = 0.96
inequity_aversion_parameters = (
@@ -51,18 +57,30 @@ def _modify_policy_to_use_welfare(rllib_config, welfare):
ia_lambda,
)
new_policies[policies_id][3].update(
- {postprocessing.ADD_INEQUITY_AVERSION_WELFARE:
- inequity_aversion_parameters}
+ {
+ postprocessing.ADD_INEQUITY_AVERSION_WELFARE: inequity_aversion_parameters
+ }
)
rllib_config["multiagent"]["policies"] = new_policies
- rllib_config["callbacks"] = callbacks.merge_callbacks(
- log.get_logging_callbacks_class(),
- postprocessing.OverwriteRewardWtWelfareCallback,
- )
+ if overwrite_reward:
+ rllib_config["callbacks"] = callbacks.merge_callbacks(
+ log.get_logging_callbacks_class(),
+ postprocessing.OverwriteRewardWtWelfareCallback,
+ )
return rllib_config
+def _get_policy_class_wt_welfare_preprocessing():
+ DQNTorchPolicyWtWelfare = augmented_dqn.MyDQNTorchPolicy.with_updates(
+ postprocess_fn=miscellaneous.merge_policy_postprocessing_fn(
+ postprocessing.welfares_postprocessing_fn(),
+ postprocess_nstep_and_prio,
+ )
+ )
+ return DQNTorchPolicyWtWelfare
+
+
if __name__ == "__main__":
debug_mode = True
main(debug_mode)
diff --git a/marltoolbox/examples/rllib_api/inequity_aversion.py b/marltoolbox/examples/rllib_api/inequity_aversion.py
index c730ea2..9e679dc 100644
--- a/marltoolbox/examples/rllib_api/inequity_aversion.py
+++ b/marltoolbox/examples/rllib_api/inequity_aversion.py
@@ -21,39 +21,35 @@ def main(debug):
}
policies = {
- env_config["players_ids"][0]:
- (
- None,
- IteratedBoSAndPD.OBSERVATION_SPACE,
- IteratedBoSAndPD.ACTION_SPACE,
- {}
- ),
- env_config["players_ids"][1]:
- (
- None,
- IteratedBoSAndPD.OBSERVATION_SPACE,
- IteratedBoSAndPD.ACTION_SPACE,
- {}
- )}
+ env_config["players_ids"][0]: (
+ None,
+ IteratedBoSAndPD.OBSERVATION_SPACE,
+ IteratedBoSAndPD.ACTION_SPACE,
+ {},
+ ),
+ env_config["players_ids"][1]: (
+ None,
+ IteratedBoSAndPD.OBSERVATION_SPACE,
+ IteratedBoSAndPD.ACTION_SPACE,
+ {},
+ ),
+ }
rllib_config = {
"env": IteratedBoSAndPD,
"env_config": env_config,
-
"num_gpus": 0,
"num_workers": 1,
-
"multiagent": {
"policies": policies,
"policy_mapping_fn": (lambda agent_id: agent_id),
},
"framework": "torch",
"gamma": 0.5,
-
"callbacks": callbacks.merge_callbacks(
log.get_logging_callbacks_class(),
- postprocessing.OverwriteRewardWtWelfareCallback),
-
+ postprocessing.OverwriteRewardWtWelfareCallback,
+ ),
}
MyPGTorchPolicy = PGTorchPolicy.with_updates(
@@ -63,18 +59,21 @@ def main(debug):
inequity_aversion_beta=1.0,
inequity_aversion_alpha=0.0,
inequity_aversion_gamma=1.0,
- inequity_aversion_lambda=0.5
+ inequity_aversion_lambda=0.5,
),
- pg_torch_policy.post_process_advantages
+ pg_torch_policy.post_process_advantages,
)
)
- MyPGTrainer = PGTrainer.with_updates(default_policy=MyPGTorchPolicy,
- get_policy_class=None)
- tune_analysis = tune.run(MyPGTrainer,
- stop=stop,
- checkpoint_freq=10,
- config=rllib_config,
- name=exp_name)
+ MyPGTrainer = PGTrainer.with_updates(
+ default_policy=MyPGTorchPolicy, get_policy_class=None
+ )
+ tune_analysis = tune.run(
+ MyPGTrainer,
+ stop=stop,
+ checkpoint_freq=10,
+ config=rllib_config,
+ name=exp_name,
+ )
ray.shutdown()
return tune_analysis
diff --git a/marltoolbox/examples/rllib_api/pg_ipd.py b/marltoolbox/examples/rllib_api/pg_ipd.py
index 769843d..89bb6df 100644
--- a/marltoolbox/examples/rllib_api/pg_ipd.py
+++ b/marltoolbox/examples/rllib_api/pg_ipd.py
@@ -2,7 +2,8 @@
import ray
from ray import tune
-from ray.rllib.agents.pg import PGTrainer
+from ray.rllib.agents.pg import PGTrainer, PGTorchPolicy
+from ray.rllib.agents.pg.pg_torch_policy import pg_loss_stats
from marltoolbox.envs.matrix_sequential_social_dilemma import (
IteratedPrisonersDilemma,
@@ -10,14 +11,14 @@
from marltoolbox.utils import log, miscellaneous
-def main(debug, stop_iters=300, tf=False):
+def main(debug):
train_n_replicates = 1 if debug else 1
seeds = miscellaneous.get_random_seeds(train_n_replicates)
exp_name, _ = log.log_in_current_day_dir("PG_IPD")
ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
- rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf)
+ rllib_config, stop_config = get_rllib_config(seeds, debug)
tune_analysis = tune.run(
PGTrainer,
config=rllib_config,
@@ -30,14 +31,16 @@ def main(debug, stop_iters=300, tf=False):
return tune_analysis
-def get_rllib_config(seeds, debug=False, stop_iters=300, tf=False):
+def get_rllib_config(seeds, debug=False):
stop_config = {
- "training_iteration": 2 if debug else stop_iters,
+ "episodes_total": 2 if debug else 400,
}
+ n_steps_in_epi = 20
+
env_config = {
"players_ids": ["player_row", "player_col"],
- "max_steps": 20,
+ "max_steps": n_steps_in_epi,
"get_additional_info": True,
}
@@ -63,8 +66,9 @@ def get_rllib_config(seeds, debug=False, stop_iters=300, tf=False):
},
"seed": tune.grid_search(seeds),
"callbacks": log.get_logging_callbacks_class(log_full_epi=True),
- "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
- "framework": "tf" if tf else "torch",
+ "framework": "torch",
+ "rollout_fragment_length": n_steps_in_epi,
+ "train_batch_size": n_steps_in_epi,
}
return rllib_config, stop_config
diff --git a/marltoolbox/examples/rllib_api/ppo_coin_game.py b/marltoolbox/examples/rllib_api/ppo_coin_game.py
index 78c10a1..3243b73 100644
--- a/marltoolbox/examples/rllib_api/ppo_coin_game.py
+++ b/marltoolbox/examples/rllib_api/ppo_coin_game.py
@@ -8,12 +8,8 @@
from marltoolbox.envs.coin_game import AsymCoinGame
from marltoolbox.utils import log, miscellaneous
-parser = argparse.ArgumentParser()
-parser.add_argument("--tf", action="store_true")
-parser.add_argument("--stop-iters", type=int, default=2000)
-
-def main(debug, stop_iters=2000, tf=False):
+def main(debug):
train_n_replicates = 1 if debug else 1
seeds = miscellaneous.get_random_seeds(train_n_replicates)
exp_name, _ = log.log_in_current_day_dir("PPO_AsymCG")
@@ -21,7 +17,7 @@ def main(debug, stop_iters=2000, tf=False):
ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
stop = {
- "training_iteration": 2 if debug else stop_iters,
+ "training_iteration": 2 if debug else 2000,
}
env_class = AsymCoinGame
@@ -66,7 +62,7 @@ def main(debug, stop_iters=2000, tf=False):
"seed": tune.grid_search(seeds),
"callbacks": log.get_logging_callbacks_class(),
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
- "framework": "tf" if tf else "torch",
+ "framework": "torch",
"num_workers": 0,
}
@@ -83,6 +79,5 @@ def main(debug, stop_iters=2000, tf=False):
if __name__ == "__main__":
- args = parser.parse_args()
debug_mode = True
- main(debug_mode, args.stop_iters, args.tf)
+ main(debug_mode)
diff --git a/marltoolbox/examples/rllib_api/r2d2_coin_game.py b/marltoolbox/examples/rllib_api/r2d2_coin_game.py
new file mode 100644
index 0000000..454f32d
--- /dev/null
+++ b/marltoolbox/examples/rllib_api/r2d2_coin_game.py
@@ -0,0 +1,249 @@
+import os
+
+import ray
+from ray import tune
+from ray.rllib.agents import dqn
+from ray.tune.integration.wandb import WandbLogger
+from ray.tune.logger import DEFAULT_LOGGERS
+from ray.rllib.agents.dqn.r2d2_torch_policy import postprocess_nstep_and_prio
+from ray.rllib.utils.schedules import PiecewiseSchedule
+
+from marltoolbox.examples.rllib_api import dqn_coin_game, dqn_wt_welfare
+from marltoolbox.scripts import aggregate_and_plot_tensorboard_data
+from marltoolbox.utils import log, miscellaneous, postprocessing
+from marltoolbox.algos import augmented_r2d2
+from marltoolbox.envs import coin_game, ssd_mixed_motive_coin_game
+
+
+def main(debug):
+ """Train R2D2 agent in the Coin Game environment"""
+
+ # env = "CoinGame"
+ env = "SSDMixedMotiveCoinGame"
+ # welfare_to_use = None
+ welfare_to_use = postprocessing.WELFARE_UTILITARIAN
+ # welfare_to_use = postprocessing.WELFARE_INEQUITY_AVERSION
+
+ rllib_config, stop_config, hparams = _get_config_and_hp_for_training(
+ debug, env, welfare_to_use
+ )
+
+ tune_analysis = _train_dqn(hparams, rllib_config, stop_config)
+
+ _plot_log_aggregates(hparams)
+
+ return tune_analysis
+
+
+def _get_config_and_hp_for_training(debug, env, welfare_to_use):
+ train_n_replicates = 1 if debug else 1
+ seeds = miscellaneous.get_random_seeds(train_n_replicates)
+ exp_name, _ = log.log_in_current_day_dir("R2D2_CG")
+ if "SSDMixedMotiveCoinGame" in env:
+ env_class = ssd_mixed_motive_coin_game.SSDMixedMotiveCoinGame
+ else:
+ env_class = coin_game.CoinGame
+
+ hparams = dqn_coin_game.get_hyperparameters(seeds, debug, exp_name)
+ rllib_config, stop_config = dqn_coin_game.get_rllib_configs(
+ hparams, env_class=env_class
+ )
+ rllib_config, stop_config = _adapt_configs_for_r2d2(
+ rllib_config, stop_config, hparams
+ )
+ if welfare_to_use is not None:
+ rllib_config = modify_r2d2_rllib_config_to_use_welfare(
+ rllib_config, welfare_to_use
+ )
+ return rllib_config, stop_config, hparams
+
+
+def modify_r2d2_rllib_config_to_use_welfare(rllib_config, welfare_to_use):
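+    # Apply the same welfare postprocessing as in the DQN example, but with
+    # the R2D2 policy class.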
+ r2d2_torch_policy_class_wt_welfare = (
+ _get_r2d2_policy_class_wt_welfare_preprocessing()
+ )
+ rllib_config = dqn_wt_welfare.modify_rllib_config_to_use_welfare(
+ rllib_config,
+ welfare_to_use,
+ policy_class_wt_welfare=r2d2_torch_policy_class_wt_welfare,
+ )
+ return rllib_config
+
+
+def _get_r2d2_policy_class_wt_welfare_preprocessing():
+ r2d2_torch_policy_class_wt_welfare = (
+ augmented_r2d2.MyR2D2TorchPolicy.with_updates(
+ postprocess_fn=miscellaneous.merge_policy_postprocessing_fn(
+ postprocessing.welfares_postprocessing_fn(),
+ postprocess_nstep_and_prio,
+ )
+ )
+ )
+ return r2d2_torch_policy_class_wt_welfare
+
+
+def _adapt_configs_for_r2d2(rllib_config, stop_config, hp):
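+    # R2D2 needs a recurrent model: the LSTM wrapper is enabled, burn-in is
+    # disabled, and the stored RNN states are reused instead of zero-init states.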
+ rllib_config["logger_config"]["wandb"]["project"] = "R2D2_CG"
+ rllib_config["model"]["use_lstm"] = True
+ rllib_config["burn_in"] = 0
+ rllib_config["zero_init_states"] = False
+ rllib_config["use_h_function"] = False
+
+ rllib_config = _replace_class_of_policies_by(
+ augmented_r2d2.MyR2D2TorchPolicy,
+ rllib_config,
+ )
+
+ if not hp["debug"]:
+ # rllib_config["env_config"]["training_intensity"] = 40
+ # rllib_config["lr"] = 0.1
+ # rllib_config["training_intensity"] = tune.sample_from(
+ # lambda spec: spec.config["num_envs_per_worker"]
+ # * spec.config["env_config"]["training_intensity"]
+ # * max(1, spec.config["num_workers"])
+ # )
+ stop_config["episodes_total"] = 8000
+ rllib_config["model"]["lstm_cell_size"] = 16
+ rllib_config["model"]["max_seq_len"] = 20
+ # rllib_config["env_config"]["buf_frac"] = 0.5
+ # rllib_config["hiddens"] = [32]
+ # rllib_config["model"]["fcnet_hiddens"] = [64, 64]
+ # rllib_config["model"]["conv_filters"] = tune.grid_search(
+ # [
+ # [[32, [3, 3], 1], [32, [3, 3], 1]],
+ # [[64, [3, 3], 1], [64, [3, 3], 1]],
+ # ]
+ # )
+
+ rllib_config["env_config"]["temp_ratio"] = (
+ 1.0 if hp["debug"] else tune.grid_search([0.75, 1.0])
+ )
+ rllib_config["env_config"]["interm_temp_ratio"] = (
+ 1.0 if hp["debug"] else tune.grid_search([0.75, 1.0])
+ )
+ rllib_config["env_config"]["last_exploration_t"] = (
+ 0.6 if hp["debug"] else tune.grid_search([0.9, 0.99])
+ )
+ rllib_config["env_config"]["last_exploration_temp_value"] = (
+ 1.0 if hp["debug"] else 0.003
+ )
+ rllib_config["env_config"]["interm_exploration_t"] = (
+ 0.2 if hp["debug"] else tune.grid_search([0.2, 0.6])
+ )
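+    # The exploration temperature schedule is built lazily (tune.sample_from) so
+    # its breakpoints can be expressed as fractions of the total number of
+    # environment steps (max_steps per episode * episodes_total).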
+ rllib_config["exploration_config"][
+ "temperature_schedule"
+ ] = tune.sample_from(
+ lambda spec: PiecewiseSchedule(
+ endpoints=[
+ (0, 1.0 * spec.config["env_config"]["temp_ratio"]),
+ (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ * spec.config["env_config"]["interm_exploration_t"]
+ ),
+ 0.6
+ * spec.config["env_config"]["temp_ratio"]
+ * spec.config["env_config"]["interm_temp_ratio"],
+ ),
+ (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ * spec.config["env_config"]["last_exploration_t"]
+ ),
+ spec.config["env_config"][
+ "last_exploration_temp_value"
+ ],
+ ),
+ ],
+ outside_value=spec.config["env_config"][
+ "last_exploration_temp_value"
+ ],
+ framework="torch",
+ )
+ )
+
+ # rllib_config["env_config"]["bs_epi_mul"] = (
+ # 4 if hp["debug"] else tune.grid_search([4, 8, 16])
+ # )
+ rllib_config["env_config"]["interm_lr_ratio"] = (
+ 0.5
+ if hp["debug"]
+ else tune.grid_search([0.5 * 3, 0.5, 0.5 / 3, 0.5 / 9])
+ )
+ rllib_config["env_config"]["interm_lr_t"] = (
+ 0.5 if hp["debug"] else tune.grid_search([0.25, 0.5, 0.75])
+ )
+ rllib_config["lr"] = 0.1 if hp["debug"] else 0.1
+ rllib_config["lr_schedule"] = tune.sample_from(
+ lambda spec: [
+ (0, 0.0),
+ (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ * 0.05
+ ),
+ spec.config.lr,
+ ),
+ (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ * spec.config["env_config"]["interm_lr_t"]
+ ),
+ spec.config.lr
+ * spec.config["env_config"]["interm_lr_ratio"],
+ ),
+ (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ ),
+ spec.config.lr / 1e9,
+ ),
+ ]
+ )
+ else:
+ rllib_config["model"]["max_seq_len"] = 2
+ rllib_config["model"]["lstm_cell_size"] = 8
+
+ return rllib_config, stop_config
+
+
+def _replace_class_of_policies_by(new_policy_class, rllib_config):
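+    # Swap the policy class (first element of each policy tuple) while keeping
+    # the observation/action spaces and per-policy configs untouched.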
+ policies = rllib_config["multiagent"]["policies"]
+ for policy_id in policies.keys():
+ policy = list(policies[policy_id])
+ policy[0] = new_policy_class
+ policies[policy_id] = tuple(policy)
+ return rllib_config
+
+
+def _train_dqn(hp, rllib_config, stop_config):
+ ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=hp["debug"])
+ tune_analysis = tune.run(
+ dqn.r2d2.R2D2Trainer,
+ config=rllib_config,
+ stop=stop_config,
+ name=hp["exp_name"],
+ log_to_file=not hp["debug"],
+ loggers=None if hp["debug"] else DEFAULT_LOGGERS + (WandbLogger,),
+ )
+ ray.shutdown()
+ return tune_analysis
+
+
+def _plot_log_aggregates(hp):
+ if not hp["debug"]:
+ aggregate_and_plot_tensorboard_data.add_summary_plots(
+ main_path=os.path.join("~/ray_results/", hp["exp_name"]),
+ plot_keys=hp["plot_keys"],
+ plot_assemble_tags_in_one_plot=hp["plot_assemblage_tags"],
+ )
+
+
+if __name__ == "__main__":
+ debug_mode = True
+ main(debug_mode)
diff --git a/marltoolbox/examples/rllib_api/r2d2_ipd.py b/marltoolbox/examples/rllib_api/r2d2_ipd.py
new file mode 100644
index 0000000..5d6d0bd
--- /dev/null
+++ b/marltoolbox/examples/rllib_api/r2d2_ipd.py
@@ -0,0 +1,68 @@
+import os
+
+import ray
+from ray import tune
+from ray.rllib.agents.dqn import R2D2Trainer
+
+from marltoolbox.algos import augmented_dqn
+from marltoolbox.envs import matrix_sequential_social_dilemma
+from marltoolbox.examples.rllib_api import pg_ipd
+from marltoolbox.scripts import aggregate_and_plot_tensorboard_data
+from marltoolbox.utils import log, miscellaneous
+
+
+def main(debug):
+ train_n_replicates = 1 if debug else 1
+ seeds = miscellaneous.get_random_seeds(train_n_replicates)
+ exp_name, _ = log.log_in_current_day_dir("R2D2_IPD")
+
+ ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
+
+ rllib_config, stop_config = pg_ipd.get_rllib_config(seeds, debug)
+ rllib_config, stop_config = _adapt_configs_for_r2d2(
+ rllib_config, stop_config, debug
+ )
+
+ tune_analysis = tune.run(
+ R2D2Trainer,
+ config=rllib_config,
+ stop=stop_config,
+ checkpoint_at_end=True,
+ name=exp_name,
+ log_to_file=True,
+ )
+
+ if not debug:
+ _plot_log_aggregates(exp_name)
+
+ ray.shutdown()
+ return tune_analysis
+
+
+def _adapt_configs_for_r2d2(rllib_config, stop_config, debug):
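+    # Reuse the PG IPD config; only switch the model to the LSTM wrapper
+    # required by R2D2 and adjust the stopping episode count.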
+ rllib_config["model"] = {"use_lstm": True}
+ stop_config["episodes_total"] = 10 if debug else 600
+
+ return rllib_config, stop_config
+
+
+def _plot_log_aggregates(exp_name):
+ plot_keys = (
+ aggregate_and_plot_tensorboard_data.PLOT_KEYS
+ + matrix_sequential_social_dilemma.PLOT_KEYS
+        + augmented_dqn.PLOT_KEYS
+ )
+ plot_assemble_tags_in_one_plot = (
+ aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS
+ + matrix_sequential_social_dilemma.PLOT_ASSEMBLAGE_TAGS
+ )
+ aggregate_and_plot_tensorboard_data.add_summary_plots(
+ main_path=os.path.join("~/ray_results/", exp_name),
+ plot_keys=plot_keys,
+ plot_assemble_tags_in_one_plot=plot_assemble_tags_in_one_plot,
+ )
+
+
+if __name__ == "__main__":
+ debug_mode = True
+ main(debug_mode)
diff --git a/marltoolbox/experiments/rllib_api/amtft_meta_game.py b/marltoolbox/experiments/rllib_api/amtft_meta_game.py
index 225aba3..1019a08 100644
--- a/marltoolbox/experiments/rllib_api/amtft_meta_game.py
+++ b/marltoolbox/experiments/rllib_api/amtft_meta_game.py
@@ -1,8 +1,9 @@
+import collections
import copy
import json
import logging
import os
-import random
+from typing import List, Dict
import numpy as np
import pandas as pd
@@ -11,91 +12,120 @@
from ray.rllib.agents import dqn
from ray.rllib.utils import merge_dicts
+from marltoolbox import utils
from marltoolbox.algos import welfare_coordination
from marltoolbox.experiments.rllib_api import amtft_various_env
from marltoolbox.utils import (
- self_and_cross_perf,
postprocessing,
- miscellaneous,
plot,
+ path,
+ cross_play,
+ tune_analysis,
+ restore,
)
logger = logging.getLogger(__name__)
def main(debug):
- if debug:
- test_meta_solver(debug)
-
- hp = _get_hyperparameters(debug)
+ # _extract_stats_on_welfare_announced(
+ # exp_dir="/home/maxime/dev-maxime/CLR/vm-data/instance-60-cpu-3-preemtible/amTFT/2021_04_23/12_58_42",
+ # players_ids=["player_row", "player_col"],
+ # )
+ #
+ # if debug:
+ # test_meta_solver(debug)
+ use_r2d2 = True
+ hp = get_hyperparameters(debug, use_r2d2=use_r2d2)
results = []
+ ray.init(num_cpus=os.cpu_count(), local_mode=hp["debug"])
for tau in hp["tau_range"]:
+ hp["tau"] = tau
(
all_rllib_config,
hp_eval,
env_config,
stop_config,
- ) = _produce_rllib_config_for_each_replicates(tau, hp)
+ ) = _produce_rllib_config_for_each_replicates(hp)
mixed_rllib_configs = _mix_rllib_config(all_rllib_config, hp_eval)
-
tune_analysis = ray.tune.run(
- dqn.DQNTrainer,
+ hp["trainer"],
config=mixed_rllib_configs,
verbose=1,
stop=stop_config,
- checkpoint_at_end=True,
name=hp_eval["exp_name"],
- log_to_file=not hp_eval["debug"],
- # loggers=None
- # if hp_eval["debug"]
- # else DEFAULT_LOGGERS + (WandbLogger,),
- )
- mean_player_1_payoffs, mean_player_2_payoffs = _extract_metrics(
- tune_analysis, hp_eval
+ # log_to_file=not hp_eval["debug"],
)
+ (
+ mean_player_1_payoffs,
+ mean_player_2_payoffs,
+ player_1_payoffs,
+ player_2_payoffs,
+ ) = extract_metrics(tune_analysis, hp_eval)
results.append(
(
tau,
(mean_player_1_payoffs, mean_player_2_payoffs),
+ (player_1_payoffs, player_2_payoffs),
)
)
- _save_to_json(exp_name=hp["exp_name"], object=results)
- _plot_results(exp_name=hp["exp_name"], results=results, hp_eval=hp_eval)
+ save_to_json(exp_name=hp["exp_name"], object=results)
+ plot_results(
+ exp_name=hp["exp_name"],
+ results=results,
+ hp_eval=hp_eval,
+ format_fn=format_result_for_plotting,
+ )
+ extract_stats_on_welfare_announced(
+ exp_name=hp["exp_name"], players_ids=env_config["players_ids"]
+ )
+
+def get_hyperparameters(debug, use_r2d2=False):
+ """Get hyperparameters for meta game with amTFT policies in base game"""
-def _get_hyperparameters(debug):
# env = "IteratedPrisonersDilemma"
env = "IteratedAsymBoS"
# env = "CoinGame"
hp = amtft_various_env.get_hyperparameters(
- debug, train_n_replicates=1, filter_utilitarian=False, env=env
+ debug,
+ train_n_replicates=1,
+ filter_utilitarian=False,
+ env=env,
+ use_r2d2=use_r2d2,
)
hp.update(
{
- "n_replicates_over_full_exp": 2,
- "final_base_game_eval_over_n_epi": 2,
+ "n_replicates_over_full_exp": 2 if debug else 5,
+ "final_base_game_eval_over_n_epi": 1 if debug else 20,
"tau_range": np.arange(0.0, 1.1, 0.5)
if hp["debug"]
else np.arange(0.0, 1.1, 0.1),
"n_self_play_in_final_meta_game": 0,
- "n_cross_play_in_final_meta_game": 1,
+ "n_cross_play_in_final_meta_game": 1 if debug else 4,
+ "max_n_replicates_in_base_game": 10,
}
)
+
+ if use_r2d2:
+ hp["trainer"] = dqn.r2d2.R2D2Trainer
+ else:
+ hp["trainer"] = dqn.DQNTrainer
+
return hp
-def _produce_rllib_config_for_each_replicates(tau, hp):
+def _produce_rllib_config_for_each_replicates(hp):
all_rllib_config = []
for replicate_i in range(hp["n_replicates_over_full_exp"]):
hp_eval = _load_base_game_results(
copy.deepcopy(hp), load_base_replicate_i=replicate_i
)
- hp_eval["tau"] = tau
(
rllib_config,
@@ -110,29 +140,55 @@ def _produce_rllib_config_for_each_replicates(tau, hp):
rllib_config, env_config, hp_eval
)
all_rllib_config.append(rllib_config)
+
return all_rllib_config, hp_eval, env_config, stop_config
def _load_base_game_results(hp, load_base_replicate_i):
- prefix = "/home/maxime/dev-maxime/CLR/vm-data/instance-10-cpu-2/"
- # prefix = "/ray_results/"
+ # prefix = "~/dev-maxime/CLR/vm-data/instance-10-cpu-2/"
+ # prefix = "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible/"
+ prefix = "~/dev-maxime/CLR/vm-data/instance-60-cpu-2-preemtible/"
+ # prefix = "~/ray_results/"
+ prefix = os.path.expanduser(prefix)
if "CoinGame" in hp["env_name"]:
hp["data_dir"] = (prefix + "amTFT/2021_04_10/19_37_20",)[
load_base_replicate_i
]
elif "IteratedAsymBoS" in hp["env_name"]:
hp["data_dir"] = (
- prefix + "amTFT/2021_04_13/11_56_23", # 5 replicates
- prefix + "amTFT/2021_04_13/13_40_03", # 5 replicates
- prefix + "amTFT/2021_04_13/13_40_34", # 5 replicates
- prefix + "amTFT/2021_04_13/18_06_48", # 10 replicates
- prefix + "amTFT/2021_04_13/18_07_05", # 10 replicates
+ # Before fix + without R2D2
+ # prefix + "amTFT/2021_04_13/11_56_23", # 5 replicates
+ # prefix + "amTFT/2021_04_13/13_40_03", # 5 replicates
+ # prefix + "amTFT/2021_04_13/13_40_34", # 5 replicates
+ # prefix + "amTFT/2021_04_13/18_06_48", # 10 replicates
+ # prefix + "amTFT/2021_04_13/18_07_05", # 10 replicates
+ # After env fix + with R2D2
+ # prefix + "amTFT/2021_05_03/13_44_48", # 10 replicates
+ # prefix + "amTFT/2021_05_03/13_45_09", # 10 replicates
+ # prefix + "amTFT/2021_05_03/13_47_00", # 10 replicates
+ # prefix + "amTFT/2021_05_03/13_48_03", # 10 replicates
+ # prefix + "amTFT/2021_05_03/13_49_28", # 10 replicates
+ # prefix + "amTFT/2021_05_03/13_49_47", # 10 replicates
+            # After env fix + short debit rollout + punish instead of
+            # selfish + 20-step rollout length instead of 6
+ prefix + "amTFT/2021_05_03/13_52_07", # 10 replicates
+ prefix + "amTFT/2021_05_03/13_52_27", # 10 replicates
+ prefix + "amTFT/2021_05_03/13_53_37", # 10 replicates
+ ### prefix + "amTFT/2021_05_03/13_54_08", # 10 replicates
+ prefix + "amTFT/2021_05_03/13_54_55", # 10 replicates
+ prefix + "amTFT/2021_05_03/13_56_35", # 10 replicates
)[load_base_replicate_i]
elif "IteratedPrisonersDilemma" in hp["env_name"]:
hp["data_dir"] = (
"/home/maxime/dev-maxime/CLR/vm-data/"
"instance-10-cpu-4/amTFT/2021_04_13/12_12_56",
)[load_base_replicate_i]
+ else:
+ raise ValueError()
+
+    assert os.path.exists(hp["data_dir"]), (
+        "Path doesn't exist. The prefix probably needs to be changed "
+        "to fit the current machine in use."
+    )
hp["json_file"] = _get_results_json_file_path_in_dir(hp["data_dir"])
hp["ckpt_per_welfare"] = _get_checkpoints_for_each_welfare_in_dir(
@@ -150,17 +206,17 @@ def _get_vanilla_amTFT_eval_config(hp, final_eval_over_n_epi):
stop_config, env_config, rllib_config = amtft_various_env.get_rllib_config(
hp_eval, welfare_fn=postprocessing.WELFARE_INEQUITY_AVERSION, eval=True
)
- rllib_config = amtft_various_env.modify_config_for_evaluation(
- rllib_config, hp_eval, env_config
+ rllib_config, stop_config = amtft_various_env.modify_config_for_evaluation(
+ rllib_config, hp_eval, env_config, stop_config
)
hp_eval["n_self_play_per_checkpoint"] = None
hp_eval["n_cross_play_per_checkpoint"] = None
hp_eval[
"x_axis_metric"
- ] = f"policy_reward_mean.{env_config['players_ids'][0]}"
+ ] = f"policy_reward_mean/{env_config['players_ids'][0]}"
hp_eval[
"y_axis_metric"
- ] = f"policy_reward_mean.{env_config['players_ids'][1]}"
+ ] = f"policy_reward_mean/{env_config['players_ids'][1]}"
return rllib_config, hp_eval, env_config, stop_config
@@ -168,8 +224,10 @@ def _get_vanilla_amTFT_eval_config(hp, final_eval_over_n_epi):
def _modify_config_to_use_welfare_coordinators(
rllib_config, env_config, hp_eval
):
- all_welfare_pairs_wt_payoffs = _get_all_welfare_pairs_wt_payoffs(
- hp_eval, env_config["players_ids"]
+ all_welfare_pairs_wt_payoffs = (
+ _get_all_welfare_pairs_wt_cross_play_payoffs(
+ hp_eval, env_config["players_ids"]
+ )
)
rllib_config["multiagent"]["policies_to_train"] = ["None"]
@@ -192,17 +250,11 @@ def _modify_config_to_use_welfare_coordinators(
"all_welfare_pairs_wt_payoffs": all_welfare_pairs_wt_payoffs,
"own_player_idx": policy_idx,
"opp_player_idx": opp_policy_idx,
- # "own_default_welfare_fn": postprocessing.WELFARE_INEQUITY_AVERSION
- # if policy_idx
- # else postprocessing.WELFARE_UTILITARIAN,
- # "opp_default_welfare_fn": postprocessing.WELFARE_INEQUITY_AVERSION
- # if opp_policy_idx
- # else postprocessing.WELFARE_UTILITARIAN,
"own_default_welfare_fn": "inequity aversion"
- if policy_idx
+ if policy_idx == 1
else "utilitarian",
"opp_default_welfare_fn": "inequity aversion"
- if opp_policy_idx
+ if opp_policy_idx == 1
else "utilitarian",
"policy_id_to_load": policy_id,
"policy_checkpoints": hp_eval["ckpt_per_welfare"],
@@ -227,7 +279,7 @@ def _mix_rllib_config(all_rllib_configs, hp_eval):
hp_eval["n_self_play_in_final_meta_game"] == 0
and hp_eval["n_cross_play_in_final_meta_game"] != 0
):
- master_config = _mix_configs_policies(
+ master_config = cross_play.utils.mix_policies_in_given_rllib_configs(
all_rllib_configs,
n_mix_per_config=hp_eval["n_cross_play_in_final_meta_game"],
)
@@ -236,100 +288,15 @@ def _mix_rllib_config(all_rllib_configs, hp_eval):
raise ValueError()
-def _mix_configs_policies(all_rllib_configs, n_mix_per_config):
- assert n_mix_per_config <= len(all_rllib_configs) - 1
- policy_ids = all_rllib_configs[0]["multiagent"]["policies"].keys()
- assert len(policy_ids) == 2
- _assert_all_config_use_the_same_policies(all_rllib_configs, policy_ids)
-
- policy_config_variants = _gather_policy_variant_per_policy_id(
- all_rllib_configs, policy_ids
- )
-
- master_config = _create_one_master_config(
- all_rllib_configs, policy_config_variants, policy_ids, n_mix_per_config
- )
- return master_config
-
-
-def _create_one_master_config(
- all_rllib_configs, policy_config_variants, policy_ids, n_mix_per_config
-):
- all_policy_mix = []
- player_1, player_2 = policy_ids
- for config_idx, p1_policy_config in enumerate(
- policy_config_variants[player_1]
- ):
- policies_mixes = _produce_n_mix_with_player_2_policies(
- policy_config_variants,
- player_2,
- config_idx,
- n_mix_per_config,
- player_1,
- p1_policy_config,
- )
- all_policy_mix.extend(policies_mixes)
-
- master_config = copy.deepcopy(all_rllib_configs[0])
- master_config["multiagent"]["policies"] = tune.grid_search(all_policy_mix)
- return master_config
-
-
-def _produce_n_mix_with_player_2_policies(
- policy_config_variants,
- player_2,
- config_idx,
- n_mix_per_config,
- player_1,
- p1_policy_config,
-):
- p2_policy_configs_sampled = _get_p2_policies_samples_excluding_self(
- policy_config_variants, player_2, config_idx, n_mix_per_config
- )
- policies_mixes = []
- for p2_policy_config in p2_policy_configs_sampled:
- policy_mix = {
- player_1: p1_policy_config,
- player_2: p2_policy_config,
- }
- policies_mixes.append(policy_mix)
- return policies_mixes
-
-
-def _get_p2_policies_samples_excluding_self(
- policy_config_variants, player_2, config_idx, n_mix_per_config
-):
- p2_policy_config_variants = copy.deepcopy(policy_config_variants[player_2])
- p2_policy_config_variants.pop(config_idx)
- p2_policy_configs_sampled = random.sample(
- p2_policy_config_variants, n_mix_per_config
- )
- return p2_policy_configs_sampled
-
-
-def _assert_all_config_use_the_same_policies(all_rllib_configs, policy_ids):
- for rllib_config in all_rllib_configs:
- assert rllib_config["multiagent"]["policies"].keys() == policy_ids
-
-
-def _gather_policy_variant_per_policy_id(all_rllib_configs, policy_ids):
- policy_config_variants = {}
- for policy_id in policy_ids:
- policy_config_variants[policy_id] = []
- for rllib_config in all_rllib_configs:
- policy_config_variants[policy_id].append(
- copy.deepcopy(
- rllib_config["multiagent"]["policies"][policy_id]
- )
- )
- return policy_config_variants
-
-
-def _extract_metrics(tune_analysis, hp_eval):
- player_1_payoffs = miscellaneous.extract_metric_values_per_trials(
+def extract_metrics(tune_analysis, hp_eval):
+    # TODO: this reads the last result even though it mixes several episodes
+    #  with different welfare functions (the resulting plot still looks correct).
+    # The metric from the last result is the average over the last training
+    # iteration, and there is only one training iteration here.
+ player_1_payoffs = utils.tune_analysis.extract_metrics_for_each_trials(
tune_analysis, metric=hp_eval["x_axis_metric"]
)
- player_2_payoffs = miscellaneous.extract_metric_values_per_trials(
+ player_2_payoffs = utils.tune_analysis.extract_metrics_for_each_trials(
tune_analysis, metric=hp_eval["y_axis_metric"]
)
mean_player_1_payoffs = sum(player_1_payoffs) / len(player_1_payoffs)
@@ -339,33 +306,50 @@ def _extract_metrics(tune_analysis, hp_eval):
mean_player_1_payoffs,
mean_player_2_payoffs,
)
- return mean_player_1_payoffs, mean_player_2_payoffs
+ return (
+ mean_player_1_payoffs,
+ mean_player_2_payoffs,
+ player_1_payoffs,
+ player_2_payoffs,
+ )
-def _save_to_json(exp_name, object):
- exp_dir = os.path.join("~/ray_results", exp_name)
- exp_dir = os.path.expanduser(exp_dir)
- json_file = os.path.join(exp_dir, "final_eval_in_base_game.json")
+class NumpyEncoder(json.JSONEncoder):
+ """ Special json encoder for numpy types """
+
+ def default(self, obj):
+ if isinstance(obj, np.integer):
+ return int(obj)
+ elif isinstance(obj, np.floating):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ return json.JSONEncoder.default(self, obj)
+
+
+def save_to_json(exp_name, object, filename="final_eval_in_base_game.json"):
+ exp_dir = _get_exp_dir_from_exp_name(exp_name)
+ json_file = os.path.join(exp_dir, filename)
+ if not os.path.exists(exp_dir):
+ os.makedirs(exp_dir)
with open(json_file, "w") as outfile:
json.dump(object, outfile)
-def _plot_results(exp_name, results, hp_eval):
- exp_dir = os.path.join("~/ray_results", exp_name)
- exp_dir = os.path.expanduser(exp_dir)
+def plot_results(exp_name, results, hp_eval, format_fn, jitter=0.0):
+ exp_dir = _get_exp_dir_from_exp_name(exp_name)
+ data_groups_per_mode = format_fn(results)
- data_groups_per_mode = _format_result_for_plotting(results)
+ background_area_coord = None
+ if "CoinGame" not in hp_eval["env_name"] and "env_class" in hp_eval.keys():
+ background_area_coord = hp_eval["env_class"].PAYOFF_MATRIX
- if "CoinGame" in hp_eval["env_name"]:
- background_area_coord = None
- else:
- background_area_coord = hp_eval["env_class"].PAYOUT_MATRIX
plot_config = plot.PlotConfig(
save_dir_path=exp_dir,
xlim=hp_eval["x_limits"],
ylim=hp_eval["y_limits"],
markersize=5,
- jitter=hp_eval["jitter"],
+ jitter=jitter,
xlabel="player 1 payoffs",
ylabel="player 2 payoffs",
x_scale_multiplier=hp_eval["plot_axis_scale_multipliers"][0],
@@ -377,12 +361,9 @@ def _plot_results(exp_name, results, hp_eval):
plot_helper.plot_dots(data_groups_per_mode)
-def _format_result_for_plotting(results):
+def format_result_for_plotting(results):
data_groups_per_mode = {}
- for (
- tau,
- (mean_player_1_payoffs, mean_player_2_payoffs),
- ) in results:
+ for (tau, (mean_player_1_payoffs, mean_player_2_payoffs), _) in results:
df_row_dict = {}
df_row_dict[f"mean"] = (
mean_player_1_payoffs,
@@ -394,7 +375,7 @@ def _format_result_for_plotting(results):
def test_meta_solver(debug):
- hp = _get_hyperparameters(debug)
+ hp = get_hyperparameters(debug)
hp = _load_base_game_results(hp, load_base_replicate_i=0)
stop, env_config, rllib_config = amtft_various_env.get_rllib_config(
hp, welfare_fn=postprocessing.WELFARE_INEQUITY_AVERSION, eval=True
@@ -422,8 +403,10 @@ def test_meta_solver(debug):
)
)
- all_welfare_pairs_wt_payoffs = _get_all_welfare_pairs_wt_payoffs(
- hp, env_config["players_ids"]
+ all_welfare_pairs_wt_payoffs = (
+ _get_all_welfare_pairs_wt_cross_play_payoffs(
+ hp, env_config["players_ids"]
+ )
)
print("all_welfare_pairs_wt_payoffs", all_welfare_pairs_wt_payoffs)
player_meta_policy.setup_meta_game(
@@ -469,120 +452,96 @@ def _get_results_json_file_path_in_dir(data_dir):
def _eval_dir(data_dir):
eval_dir_parents_2 = os.path.join(data_dir, "eval")
- eval_dir_parents_1 = _get_unique_child_dir(eval_dir_parents_2)
- eval_dir = _get_unique_child_dir(eval_dir_parents_1)
+ eval_dir_parents_1 = path.get_unique_child_dir(eval_dir_parents_2)
+ eval_dir = path.get_unique_child_dir(eval_dir_parents_1)
return eval_dir
def _get_json_results_path(eval_dir):
- all_files_filtered = _get_children_paths_wt_selecting_filter(
- eval_dir, _filter=self_and_cross_perf.RESULTS_SUMMARY_FILENAME_PREFIX
+ all_files_filtered = path.get_children_paths_wt_selecting_filter(
+ eval_dir,
+ _filter=cross_play.evaluator.RESULTS_SUMMARY_FILENAME_PREFIX,
)
all_files_filtered = [
file_path for file_path in all_files_filtered if ".json" in file_path
]
+ all_files_filtered = [
+ file_path
+ for file_path in all_files_filtered
+ if os.path.split(file_path)[-1].startswith("plotself")
+ or not os.path.split(file_path)[-1].startswith("plot")
+ ]
assert len(all_files_filtered) == 1, f"{all_files_filtered}"
json_result_path = os.path.join(eval_dir, all_files_filtered[0])
return json_result_path
-def _get_unique_child_dir(_dir):
- list_dir = os.listdir(_dir)
- assert len(list_dir) == 1, f"{list_dir}"
- unique_child_dir = os.path.join(_dir, list_dir[0])
- return unique_child_dir
-
-
def _get_checkpoints_for_each_welfare_in_dir(data_dir, hp):
+ """Get the checkpoints of the base game policies in self-play"""
ckpt_per_welfare = {}
for welfare_fn, welfare_name in hp["welfare_functions"]:
welfare_training_save_dir = os.path.join(data_dir, welfare_fn, "coop")
- all_replicates_save_dir = _get_dir_of_each_replicate(
- welfare_training_save_dir
+ all_replicates_save_dir = _get_replicates_dir_for_all_possible_config(
+ hp, welfare_training_save_dir
)
- ckpts = _get_checkpoint_for_each_replicates(all_replicates_save_dir)
+ all_replicates_save_dir = _filter_checkpoints_if_utilitarians(
+ all_replicates_save_dir, hp
+ )
+ ckpts = restore.get_checkpoint_for_each_replicates(
+ all_replicates_save_dir
+ )
+ print("len(ckpts)", len(ckpts))
ckpt_per_welfare[welfare_name.replace("_", " ")] = ckpts
return ckpt_per_welfare
-def _get_dir_of_each_replicate(welfare_training_save_dir):
- return _get_children_paths_wt_selecting_filter(
- welfare_training_save_dir, _filter="DQN_"
- )
-
-
-def _get_checkpoint_for_each_replicates(all_replicates_save_dir):
- ckpt_dir_per_replicate = []
- for replicate_dir_path in all_replicates_save_dir:
- ckpt_dir_path = _get_ckpt_dir_for_one_replicate(replicate_dir_path)
- ckpt_path = _get_ckpt_fro_ckpt_dir(ckpt_dir_path)
- ckpt_dir_per_replicate.append(ckpt_path)
- return ckpt_dir_per_replicate
-
-
-def _get_ckpt_dir_for_one_replicate(replicate_dir_path):
- partialy_filtered_ckpt_dir = _get_children_paths_wt_selecting_filter(
- replicate_dir_path, _filter="checkpoint_"
- )
- ckpt_dir = [
- file_path
- for file_path in partialy_filtered_ckpt_dir
- if ".is_checkpoint" not in file_path
- ]
- assert len(ckpt_dir) == 1, f"{ckpt_dir}"
- return ckpt_dir[0]
+def _get_replicates_dir_for_all_possible_config(hp, welfare_training_save_dir):
+ if "use_r2d2" in hp and hp["use_r2d2"]:
+ all_replicates_save_dir = get_dir_of_each_replicate(
+ welfare_training_save_dir, str_in_dir="R2D2_"
+ )
+ else:
+ all_replicates_save_dir = get_dir_of_each_replicate(
+ welfare_training_save_dir, str_in_dir="DQN_"
+ )
+ return all_replicates_save_dir
-def _get_ckpt_fro_ckpt_dir(ckpt_dir_path):
- partialy_filtered_ckpt_path = _get_children_paths_wt_discarding_filter(
- ckpt_dir_path, _filter="tune_metadata"
- )
- ckpt_path = [
- file_path
- for file_path in partialy_filtered_ckpt_path
- if ".is_checkpoint" not in file_path
+def _filter_checkpoints_if_utilitarians(all_replicates_save_dir, hp):
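+    # When every replicate trained the utilitarian welfare, keep only those
+    # whose episode_reward_mean is above a fixed threshold, then cap their
+    # number at max_n_replicates_in_base_game.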
+ util_or_inequity_aversion = [
+ path.split("/")[-3] for path in all_replicates_save_dir
]
- assert len(ckpt_path) == 1, f"{ckpt_path}"
- return ckpt_path[0]
-
-
-def _get_children_paths_wt_selecting_filter(parent_dir_path, _filter):
- return _get_children_paths_filters(
- parent_dir_path, selecting_filter=_filter
- )
+ print("util_or_inequity_aversion", util_or_inequity_aversion)
+ if all(
+ welfare == "utilitarian_welfare"
+ for welfare in util_or_inequity_aversion
+ ):
+ all_replicates_save_dir = (
+ utils.path.filter_list_of_replicates_by_results(
+ all_replicates_save_dir,
+ filter_key="episode_reward_mean",
+ filter_mode=tune_analysis.ABOVE,
+ filter_threshold=95,
+ )
+ )
+ if len(all_replicates_save_dir) > hp["max_n_replicates_in_base_game"]:
+ all_replicates_save_dir = all_replicates_save_dir[
+ : hp["max_n_replicates_in_base_game"]
+ ]
+ else:
+ print(
+ "not filtering ckpts. all_replicates_save_dir[0]",
+ all_replicates_save_dir[0],
+ )
+ return all_replicates_save_dir
-def _get_children_paths_wt_discarding_filter(parent_dir_path, _filter):
- return _get_children_paths_filters(
- parent_dir_path, discarding_filter=_filter
+def get_dir_of_each_replicate(welfare_training_save_dir, str_in_dir="DQN_"):
+ return path.get_children_paths_wt_selecting_filter(
+ welfare_training_save_dir, _filter=str_in_dir
)
-def _get_children_paths_filters(
- parent_dir_path: str,
- selecting_filter: str = None,
- discarding_filter: str = None,
-):
- filtered_children = os.listdir(parent_dir_path)
- if selecting_filter is not None:
- filtered_children = [
- filename
- for filename in filtered_children
- if selecting_filter in filename
- ]
- if discarding_filter is not None:
- filtered_children = [
- filename
- for filename in filtered_children
- if discarding_filter not in filename
- ]
- filtered_children_path = [
- os.path.join(parent_dir_path, filename)
- for filename in filtered_children
- ]
- return filtered_children_path
-
-
def _convert_policy_config_to_meta_policy_config(
hp, policy_id, policy_config, policy_class
):
@@ -601,18 +560,18 @@ def _convert_policy_config_to_meta_policy_config(
return meta_policy_config
-def _get_all_welfare_pairs_wt_payoffs(hp, player_ids):
+def _get_all_welfare_pairs_wt_cross_play_payoffs(hp, player_ids):
with open(hp["json_file"]) as json_file:
json_data = json.load(json_file)
- cross_play_data = _keep_only_cross_play_values(json_data)
+ cross_play_data = keep_only_cross_play_values(json_data)
cross_play_means = _keep_only_mean_values(cross_play_data)
all_welfare_pairs_wt_payoffs = _order_players(cross_play_means, player_ids)
return all_welfare_pairs_wt_payoffs
-def _keep_only_cross_play_values(json_data):
+def keep_only_cross_play_values(json_data):
return {
_format_eval_mode(eval_mode): v
for eval_mode, v in json_data.items()
@@ -623,7 +582,7 @@ def _keep_only_cross_play_values(json_data):
def _format_eval_mode(eval_mode):
k_wtout_kind_of_play = eval_mode.split(":")[-1].strip()
both_welfare_fn = k_wtout_kind_of_play.split(" vs ")
- return welfare_coordination.WelfareCoordinationTorchPolicy._from_pair_of_welfare_names_to_key(
+ return welfare_coordination.WelfareCoordinationTorchPolicy.from_pair_of_welfare_names_to_key(
*both_welfare_fn
)
@@ -648,6 +607,118 @@ def _order_players(cross_play_means, player_ids):
}
+def extract_stats_on_welfare_announced(
+ players_ids, exp_name=None, exp_dir=None, nested_info=False
+):
+ if exp_dir is None:
+ exp_dir = _get_exp_dir_from_exp_name(exp_name)
+ all_in_exp_dir = utils.path.get_children_paths_wt_discarding_filter(
+ exp_dir, _filter=None
+ )
+ all_dirs = utils.path.keep_dirs_only(all_in_exp_dir)
+ dir_name_by_tau = _group_dir_name_by_tau(all_dirs, nested_info)
+ dir_name_by_tau = _order_by_tau(dir_name_by_tau)
+ _get_stats_for_each_tau(dir_name_by_tau, players_ids, exp_dir)
+
+
+def _get_exp_dir_from_exp_name(exp_name: str):
+ exp_dir = os.path.join("~/ray_results", exp_name)
+ exp_dir = os.path.expanduser(exp_dir)
+ return exp_dir
+
+
+def _group_dir_name_by_tau(all_dirs: List[str], nested_info) -> Dict:
+ dirs_by_tau = {}
+ for trial_dir in all_dirs:
+ tau, welfare_set_annonced = _get_tau_value(trial_dir, nested_info)
+ if tau is None:
+ continue
+ if tau not in dirs_by_tau.keys():
+ dirs_by_tau[tau] = []
+ dirs_by_tau[tau].append((trial_dir, welfare_set_annonced))
+ return dirs_by_tau
+
+
+def _get_tau_value(trial_dir_path, nested_info=False):
+ full_epi_file_path = os.path.join(
+ trial_dir_path, "full_episodes_logs.json"
+ )
+ if os.path.exists(full_epi_file_path):
+ full_epi_logs = utils.path._read_all_lines_of_file(full_epi_file_path)
+ first_epi_first_step = json.loads(full_epi_logs[1])
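+        # tau and the announced welfare sets are recovered from the info logged
+        # by each meta policy at the first step of the first logged episode.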
+ tau = []
+ welfare_set_annonced = {}
+ for policy_id, policy_info in first_epi_first_step.items():
+ print('policy_info["info"]', policy_info["info"])
+ if nested_info:
+ meta_policy_info = policy_info["info"][
+ f"meta_policy/{policy_id}"
+ ]
+ else:
+ meta_policy_info = policy_info["info"][f"meta_policy"]
+ for k, v in meta_policy_info.items():
+ if k.startswith("tau_"):
+ tau_value = float(k.split("_")[-1])
+ tau.append(tau_value)
+ welfare_set_annonced[policy_id] = v["welfare_set_annonced"]
+ tau = set(tau)
+ assert len(tau) == 1, f"tau {tau}"
+ tau = list(tau)[0]
+ assert len(welfare_set_annonced.keys()) == 2
+ return tau, welfare_set_annonced
+ else:
+ return None, None
+
+
+def _order_by_tau(dir_name_by_tau):
+ return collections.OrderedDict(sorted(dir_name_by_tau.items()))
+
+
+def _get_stats_for_each_tau(dir_name_by_tau, players_ids, exp_dir):
+ file_path = os.path.join(exp_dir, "welfare_announced_by_tau.txt")
+ with open(file_path, "w") as f:
+ for tau, dirs_data in dir_name_by_tau.items():
+ all_welfares_player_1 = []
+ all_welfares_player_2 = []
+ for dir, welfare_announced in dirs_data:
+ all_welfares_player_1.append(welfare_announced[players_ids[0]])
+ all_welfares_player_2.append(welfare_announced[players_ids[1]])
+ all_welfares_player_1 = _format_in_same_order(
+ all_welfares_player_1
+ )
+ all_welfares_player_2 = _format_in_same_order(
+ all_welfares_player_2
+ )
+ count_announced_p1 = collections.Counter(all_welfares_player_1)
+ count_announced_p2 = collections.Counter(all_welfares_player_2)
+ msg = (
+ f"===== Welfare sets announced with tau = {tau} =====\n"
+ f"Player 1: {count_announced_p1}\n"
+ f"Player 2: {count_announced_p2}\n"
+ )
+ print(msg)
+ f.write(msg)
+
+
+def _format_in_same_order(all_welfares_player_n):
+ formatted_welfare_announced = []
+ for welfare_announced in all_welfares_player_n:
+ formated_name = ""
+ if "utilitarian" in welfare_announced:
+ formated_name += "utilitarian + "
+ if (
+ "inequity" in welfare_announced
+ or "egalitarian" in welfare_announced
+ ):
+ formated_name += "egalitarian + "
+ if "mixed" in welfare_announced:
+ formated_name += "mixed"
+ if formated_name.endswith(" + "):
+ formated_name = formated_name[:-3]
+ formatted_welfare_announced.append(formated_name)
+ return formatted_welfare_announced
+
+
if __name__ == "__main__":
- debug_mode = False
+ debug_mode = True
main(debug_mode)
diff --git a/marltoolbox/experiments/rllib_api/amtft_various_env.py b/marltoolbox/experiments/rllib_api/amtft_various_env.py
index f6900c8..83e7042 100644
--- a/marltoolbox/experiments/rllib_api/amtft_various_env.py
+++ b/marltoolbox/experiments/rllib_api/amtft_various_env.py
@@ -5,17 +5,17 @@
import ray
from ray import tune
from ray.rllib.agents import dqn
-from ray.rllib.agents.dqn.dqn_torch_policy import postprocess_nstep_and_prio
+from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.schedules import PiecewiseSchedule
-from ray.tune.integration.wandb import WandbLogger
-from ray.tune.logger import DEFAULT_LOGGERS
+from ray.tune.integration.wandb import WandbLoggerCallback
-from marltoolbox.algos import amTFT
+from marltoolbox import utils
+from marltoolbox.algos import amTFT, augmented_r2d2
from marltoolbox.envs import (
matrix_sequential_social_dilemma,
- vectorized_coin_game,
- vectorized_mixed_motive_coin_game,
+ coin_game,
+ mixed_motive_coin_game,
ssd_mixed_motive_coin_game,
)
from marltoolbox.envs.utils.wrappers import (
@@ -28,21 +28,30 @@
postprocessing,
miscellaneous,
plot,
- self_and_cross_perf,
callbacks,
+ cross_play,
+ config_helper,
)
logger = logging.getLogger(__name__)
-def main(debug, train_n_replicates=None, filter_utilitarian=None, env=None):
+def main(
+ debug,
+ train_n_replicates=None,
+ filter_utilitarian=None,
+ env=None,
+ use_r2d2=False,
+):
hparams = get_hyperparameters(
- debug, train_n_replicates, filter_utilitarian, env
+ debug, train_n_replicates, filter_utilitarian, env, use_r2d2=use_r2d2
)
if hparams["load_plot_data"] is None:
ray.init(
- num_cpus=os.cpu_count(), num_gpus=0, local_mode=hparams["debug"]
+ num_gpus=0,
+ num_cpus=os.cpu_count(),
+ local_mode=hparams["debug"],
)
# Train
@@ -76,13 +85,14 @@ def get_hyperparameters(
filter_utilitarian=None,
env=None,
reward_uncertainty=0.0,
+ use_r2d2=False,
):
if debug:
train_n_replicates = 2
n_times_more_utilitarians_seeds = 1
elif train_n_replicates is None:
n_times_more_utilitarians_seeds = 4
- train_n_replicates = 4
+ train_n_replicates = 10
else:
n_times_more_utilitarians_seeds = 4
@@ -91,8 +101,31 @@ def get_hyperparameters(
)
pool_of_seeds = miscellaneous.get_random_seeds(n_seeds_to_prepare)
exp_name, _ = log.log_in_current_day_dir("amTFT")
+ # prefix = "/home/maxime/dev-maxime/CLR/vm-data/instance-10-cpu-1/amTFT/2021_04_29/12_16_29/"
+ prefix = (
+ "/home/maxime/dev-maxime/CLR/vm-data/instance-10-cpu-1/amTFT"
+ "/2021_04_30/22_41_32/"
+ )
+ # prefix = "~/ray_results/amTFT/2021_04_29/12_16_29/"
+ util_prefix = os.path.join(prefix, "utilitarian_welfare/coop")
+ inequity_aversion_prefix = os.path.join(
+ prefix, "inequity_aversion_welfare/coop"
+ )
+ util_load_data_list = [
+ "R2D2_AsymVectorizedCoinGame_05122_00000_0_buffer_size=400000,temperature_schedule= "torch.optim.Optimizer":
@@ -359,7 +396,7 @@ def sgd_optimizer_dqn(policy, config) -> "torch.optim.Optimizer":
MyDQNTorchPolicy = DQNTorchPolicy.with_updates(
stats_fn=log.augment_stats_fn_wt_additionnal_logs(build_q_stats),
optimizer_fn=sgd_optimizer_dqn,
- after_init=after_init_fn,
+ before_loss_init=before_loss_init_fn,
)
if tune_config["env_class"] in (
diff --git a/marltoolbox/experiments/tune_class_api/lola_dice_official.py b/marltoolbox/experiments/tune_class_api/lola_dice_official.py
index ee23242..a072cd1 100644
--- a/marltoolbox/experiments/tune_class_api/lola_dice_official.py
+++ b/marltoolbox/experiments/tune_class_api/lola_dice_official.py
@@ -10,6 +10,7 @@
import ray
from ray import tune
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
+from ray.tune.integration.wandb import WandbLoggerCallback
from marltoolbox.algos.lola_dice.train_tune_class_API import LOLADICE
from marltoolbox.envs.coin_game import CoinGame, AsymCoinGame
@@ -36,6 +37,13 @@ def main(debug):
# Example: "load_plot_data": ".../SameAndCrossPlay_save.p",
"exp_name": exp_name,
"train_n_replicates": train_n_replicates,
+ "wandb": {
+ "project": "LOLA_DICE",
+ "group": exp_name,
+ "api_key_file": os.path.join(
+ os.path.dirname(__file__), "../../../api_key_wandb"
+ ),
+ },
"env_name": "IPD",
# "env_name": "IMP",
# "env_name": "AsymBoS",
@@ -85,6 +93,16 @@ def train(hp):
stop=stop,
metric=hp["metric"],
mode="max",
+ callbacks=None
+ if hp["debug"]
+ else [
+ WandbLoggerCallback(
+ project=hp["wandb"]["project"],
+ group=hp["wandb"]["group"],
+ api_key_file=hp["wandb"]["api_key_file"],
+ log_config=True,
+ )
+ ],
)
tune_analysis_per_exp = {"": tune_analysis}
return tune_analysis_per_exp
diff --git a/marltoolbox/experiments/tune_class_api/lola_exact_meta_game.py b/marltoolbox/experiments/tune_class_api/lola_exact_meta_game.py
new file mode 100644
index 0000000..0a7acaa
--- /dev/null
+++ b/marltoolbox/experiments/tune_class_api/lola_exact_meta_game.py
@@ -0,0 +1,484 @@
+##########
+# Additional dependencies are needed:
+# Follow the LOLA installation described in the
+# tune_class_api/lola_pg_official.py file
+##########
+
+import copy
+import logging
+import os
+
+import numpy as np
+import ray
+from ray import tune
+from ray.rllib.agents.pg import PGTrainer
+
+from marltoolbox import utils
+from marltoolbox.algos import welfare_coordination
+from marltoolbox.experiments.rllib_api import amtft_meta_game
+from marltoolbox.experiments.tune_class_api import lola_exact_official
+from marltoolbox.utils import (
+ cross_play,
+ restore,
+ path,
+ callbacks,
+ log,
+ miscellaneous,
+)
+
+logger = logging.getLogger(__name__)
+
+
+EGALITARIAN = "egalitarian"
+MIXED = "mixed"
+UTILITARIAN = "utilitarian"
+FAILURE = "failure"
+
+
+def main(debug):
+ # amtft_meta_game._extract_stats_on_welfare_announced(
+ # players_ids=["player_row", "player_col"],
+ # exp_dir="/home/maxime/dev-maxime/CLR/vm-data/instance-10-cpu-2"
+ # "/LOLA_Exact/2021_04_21/17_53_53",
+ # nested_info=True,
+ # )
+
+ hp = get_hyperparameters(debug)
+
+ results = []
+ ray.init(num_cpus=os.cpu_count(), local_mode=hp["debug"])
+ for tau in hp["tau_range"]:
+ hp["tau"] = tau
+ (
+ all_rllib_config,
+ hp_eval,
+ env_config,
+ stop_config,
+ ) = _produce_rllib_config_for_each_replicates(hp)
+
+ mixed_rllib_configs = (
+ cross_play.utils.mix_policies_in_given_rllib_configs(
+ all_rllib_config, hp_eval["n_cross_play_in_final_meta_game"]
+ )
+ )
+
+ tune_analysis = ray.tune.run(
+ PGTrainer,
+ config=mixed_rllib_configs,
+ verbose=1,
+ stop=stop_config,
+ name=hp_eval["exp_name"],
+ log_to_file=not hp_eval["debug"],
+ )
+
+ (
+ mean_player_1_payoffs,
+ mean_player_2_payoffs,
+ player_1_payoffs,
+ player_2_payoffs,
+ ) = amtft_meta_game.extract_metrics(tune_analysis, hp_eval)
+
+ results.append(
+ (
+ tau,
+ (mean_player_1_payoffs, mean_player_2_payoffs),
+ (player_1_payoffs, player_2_payoffs),
+ )
+ )
+ amtft_meta_game.save_to_json(exp_name=hp["exp_name"], object=results)
+ amtft_meta_game.plot_results(
+ exp_name=hp["exp_name"],
+ results=results,
+ hp_eval=hp_eval,
+ format_fn=amtft_meta_game.format_result_for_plotting,
+ )
+ amtft_meta_game.extract_stats_on_welfare_announced(
+ players_ids=env_config["players_ids"],
+ exp_name=hp["exp_name"],
+ nested_info=True,
+ )
+
+
+def get_hyperparameters(debug):
+ """Get hyperparameters for meta game with LOLA-Exact policies in base
+ game"""
+ # env = "IPD"
+ env = "IteratedAsymBoS"
+
+ hp = lola_exact_official.get_hyperparameters(
+ debug, train_n_replicates=1, env=env
+ )
+
+ hp.update(
+ {
+ "n_replicates_over_full_exp": 2 if debug else 20,
+ "final_base_game_eval_over_n_epi": 1 if debug else 200,
+ "tau_range": np.arange(0.0, 1.1, 0.5)
+ if hp["debug"]
+ else np.arange(0.0, 1.1, 0.1),
+ "n_self_play_in_final_meta_game": 0,
+ "n_cross_play_in_final_meta_game": 1 if debug else 10,
+ "welfare_functions": [
+ (EGALITARIAN, EGALITARIAN),
+ (MIXED, MIXED),
+ (UTILITARIAN, UTILITARIAN),
+ ],
+ }
+ )
+ return hp
+
+
+def _produce_rllib_config_for_each_replicates(hp):
+ all_rllib_config = []
+ for replicate_i in range(hp["n_replicates_over_full_exp"]):
+ hp_eval = _load_base_game_results(
+ copy.deepcopy(hp), load_base_replicate_i=replicate_i
+ )
+
+ (
+ rllib_config,
+ hp_eval,
+ env_config,
+ stop_config,
+ ) = _get_vanilla_lola_exact_eval_config(
+ hp_eval, hp_eval["final_base_game_eval_over_n_epi"]
+ )
+
+ rllib_config = _modify_config_to_use_welfare_coordinators(
+ rllib_config, env_config, hp_eval
+ )
+ all_rllib_config.append(rllib_config)
+ return all_rllib_config, hp_eval, env_config, stop_config
+
+
+def _load_base_game_results(hp, load_base_replicate_i):
+
+ # In local machine
+ # prefix = "~/dev-maxime/CLR/vm-data/instance-10-cpu-2/"
+ # prefix = "~/dev-maxime/CLR/vm-data/instance-10-cpu-2/"
+ # prefix = "~/dev-maxime/CLR/vm-data/instance-60-cpu-2-preemtible/"
+ prefix = "~/dev-maxime/CLR/vm-data/instance-60-cpu-3-preemtible/"
+ prefix2 = "~/dev-maxime/CLR/vm-data/instance-60-cpu-4-preemtible/"
+
+ # In VM
+ # prefix = "~/ray_results/"
+ # prefix2 = prefix
+
+ prefix = os.path.expanduser(prefix)
+ prefix2 = os.path.expanduser(prefix2)
+ if "IteratedAsymBoS" in hp["env_name"]:
+ hp["data_dir"] = (
+ # instance-60-cpu-3-preemtible & instance-60-cpu-4-preemtible
+ prefix + "LOLA_Exact/2021_05_05/14_49_18", # 30 replicates
+ prefix + "LOLA_Exact/2021_05_05/14_50_39", # 30 replicates
+ prefix + "LOLA_Exact/2021_05_05/14_51_01", # 30 replicates
+ prefix + "LOLA_Exact/2021_05_05/14_53_56", # 30 replicates
+ prefix + "LOLA_Exact/2021_05_05/14_56_32", # 30 replicates
+ prefix + "LOLA_Exact/2021_05_05/15_46_08", # 30 replicates
+ prefix + "LOLA_Exact/2021_05_05/15_46_23", # 30 replicates
+ prefix + "LOLA_Exact/2021_05_05/15_46_59", # 30 replicates
+ prefix + "LOLA_Exact/2021_05_05/15_47_22", # 30 replicates
+ prefix + "LOLA_Exact/2021_05_05/15_48_22", # 30 replicates
+ # instance-60-cpu-4-preemtible
+ prefix2 + "LOLA_Exact/2021_05_07/07_52_32", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/08_02_38", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/08_02_49", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/08_03_03", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/08_54_58", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/08_55_34", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/09_04_07", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/09_09_30", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/09_09_42", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/10_02_15", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/10_02_30", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/10_02_39", # 30 replicates
+ prefix2 + "LOLA_Exact/2021_05_07/10_02_50", # 30 replicates
+ )[load_base_replicate_i]
+ else:
+ raise ValueError(f'bad env_name: {hp["env_name"]}')
+
+ assert os.path.exists(hp["data_dir"]), (
+ "Path doesn't exist. Probably that the prefix need to "
+ f"be changed to fit the current machine used. path: {hp['data_dir']}"
+ )
+
+ print("==== Going to process data_dir", hp["data_dir"], "====")
+
+ hp["ckpt_per_welfare"] = _get_checkpoints_for_each_welfare_in_dir(
+ hp["data_dir"], hp
+ )
+
+ return hp
+
+
+def _get_checkpoints_for_each_welfare_in_dir(data_dir, hp):
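+    # Classify each base-game replicate into a welfare function from its final
+    # returns, then group the corresponding checkpoints per welfare.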
+ all_replicates_save_dir = amtft_meta_game.get_dir_of_each_replicate(
+ data_dir, str_in_dir="LOLAExactTrainer_"
+ )
+ assert len(all_replicates_save_dir) > 0
+ welfares = _classify_base_replicates_into_welfares(all_replicates_save_dir)
+
+ ckpt_per_welfare = {}
+ for welfare_fn, welfare_name in hp["welfare_functions"]:
+ replicates_save_dir_for_welfare = _filter_replicate_dir_by_welfare(
+ all_replicates_save_dir, welfares, welfare_name
+ )
+ ckpts = restore.get_checkpoint_for_each_replicates(
+ replicates_save_dir_for_welfare
+ )
+ ckpt_per_welfare[welfare_name] = [ckpt + ".json" for ckpt in ckpts]
+ return ckpt_per_welfare
+
+
+def _classify_base_replicates_into_welfares(all_replicates_save_dir):
+ welfares = []
+ for replicate_dir in all_replicates_save_dir:
+ reward_player_1, reward_player_2 = _get_last_episode_rewards(
+ replicate_dir
+ )
+ welfare_name = classify_into_welfare_based_on_rewards(
+ reward_player_1, reward_player_2
+ )
+ welfares.append(welfare_name)
+ return welfares
+
+
+def classify_into_welfare_based_on_rewards(reward_player_1, reward_player_2):
+
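+    # Heuristic classification of a base-game outcome from the ratio of the
+    # two players' returns.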
+ ratio = reward_player_1 / reward_player_2
+ if ratio < 1.5:
+ return EGALITARIAN
+ elif ratio < 2.5:
+ return MIXED
+ else:
+ return UTILITARIAN
+
+
+def _filter_replicate_dir_by_welfare(
+ all_replicates_save_dir, welfares, welfare_name
+):
+ replicates_save_dir_for_welfare = [
+ replicate_dir
+ for welfare, replicate_dir in zip(welfares, all_replicates_save_dir)
+ if welfare == welfare_name
+ ]
+ return replicates_save_dir_for_welfare
+
+
+def _get_last_episode_rewards(replicate_dir):
+ results = utils.path.get_results_for_replicate(replicate_dir)
+ last_epsiode_results = results[-1]
+ return last_epsiode_results["ret1"], last_epsiode_results["ret2"]
+
+
+def _get_vanilla_lola_exact_eval_config(hp, final_eval_over_n_epi):
+ (
+ hp_eval,
+ rllib_config,
+ policies_to_load,
+ trainable_class,
+ stop_config,
+ env_config,
+ ) = lola_exact_official.generate_eval_config(hp)
+
+ hp_eval["n_self_play_per_checkpoint"] = None
+ hp_eval["n_cross_play_per_checkpoint"] = None
+ hp_eval[
+ "x_axis_metric"
+ ] = f"policy_reward_mean/{env_config['players_ids'][0]}"
+ hp_eval[
+ "y_axis_metric"
+ ] = f"policy_reward_mean/{env_config['players_ids'][1]}"
+ hp_eval["plot_axis_scale_multipliers"] = (
+ 1 / hp_eval["trace_length"],
+ 1 / hp_eval["trace_length"],
+ )
+ hp_eval["num_episodes"] = final_eval_over_n_epi
+ stop_config["episodes_total"] = final_eval_over_n_epi
+ rllib_config["callbacks"] = callbacks.merge_callbacks(
+ callbacks.PolicyCallbacks,
+ log.get_logging_callbacks_class(
+ log_full_epi=True,
+ log_full_epi_interval=1,
+ log_from_policy_in_evaluation=True,
+ ),
+ )
+ rllib_config["seed"] = miscellaneous.get_random_seeds(1)[0]
+ rllib_config["log_level"] = "INFO"
+
+ return rllib_config, hp_eval, env_config, stop_config
+
+
+def _modify_config_to_use_welfare_coordinators(
+ rllib_config, env_config, hp_eval
+):
+ all_welfare_pairs_wt_payoffs = (
+ _get_all_welfare_pairs_wt_cross_play_payoffs(
+ hp_eval, env_config["players_ids"]
+ )
+ )
+
+ rllib_config["multiagent"]["policies_to_train"] = ["None"]
+ policies = rllib_config["multiagent"]["policies"]
+ for policy_idx, policy_id in enumerate(env_config["players_ids"]):
+ policy_config_items = list(policies[policy_id])
+ opp_policy_idx = (policy_idx + 1) % 2
+
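+        # Wrap the base-game policy into a welfare coordination meta policy that
+        # announces a welfare set and loads the matching base-game checkpoint.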
+ meta_policy_config = copy.deepcopy(welfare_coordination.DEFAULT_CONFIG)
+ meta_policy_config.update(
+ {
+ "nested_policies": [
+ {
+ "Policy_class": copy.deepcopy(policy_config_items[0]),
+ "config_update": copy.deepcopy(policy_config_items[3]),
+ },
+ ],
+ "solve_meta_game_after_init": True,
+ "tau": hp_eval["tau"],
+ "all_welfare_pairs_wt_payoffs": all_welfare_pairs_wt_payoffs,
+ "own_player_idx": policy_idx,
+ "opp_player_idx": opp_policy_idx,
+ "own_default_welfare_fn": EGALITARIAN
+ if policy_idx == 1
+ else UTILITARIAN,
+ "opp_default_welfare_fn": EGALITARIAN
+ if opp_policy_idx == 1
+ else UTILITARIAN,
+ "policy_id_to_load": policy_id,
+ "policy_checkpoints": hp_eval["ckpt_per_welfare"],
+ }
+ )
+ policy_config_items[
+ 0
+ ] = welfare_coordination.WelfareCoordinationTorchPolicy
+ policy_config_items[3] = meta_policy_config
+ policies[policy_id] = tuple(policy_config_items)
+
+ return rllib_config
+
+
+def _get_all_welfare_pairs_wt_cross_play_payoffs(hp, player_ids):
+ all_eval_replicates_dirs = _get_list_of_replicates_path_in_eval(hp)
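+    # For every evaluation replicate: find which welfare checkpoints each player
+    # loaded, keep only cross-play pairings, then average the payoffs per pair
+    # of welfare functions.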
+
+ raw_data_points_wt_welfares = {}
+ for eval_replicate_path in all_eval_replicates_dirs:
+ players_ckpts = _extract_checkpoints_used_for_each_players(
+ player_ids, eval_replicate_path
+ )
+ if _is_cross_play(players_ckpts):
+ players_welfares = _convert_checkpoint_names_to_welfares(
+ hp, players_ckpts
+ )
+ raw_players_perf = _extract_performance(
+ eval_replicate_path, player_ids
+ )
+ play_mode = _get_play_mode(players_welfares)
+ if play_mode not in raw_data_points_wt_welfares.keys():
+ raw_data_points_wt_welfares[play_mode] = []
+ raw_data_points_wt_welfares[play_mode].append(raw_players_perf)
+ all_welfare_pairs_wt_payoffs = _average_perf_per_play_mode(
+ raw_data_points_wt_welfares, hp
+ )
+ print("all_welfare_pairs_wt_payoffs", all_welfare_pairs_wt_payoffs)
+ return all_welfare_pairs_wt_payoffs
+
+
+def _get_list_of_replicates_path_in_eval(hp):
+ child_dirs = utils.path.get_children_paths_wt_discarding_filter(
+ hp["data_dir"], _filter="LOLAExact"
+ )
+ child_dirs = utils.path.keep_dirs_only(child_dirs)
+ assert len(child_dirs) == 1, f"{child_dirs}"
+ eval_dir = utils.path.get_unique_child_dir(child_dirs[0])
+ eval_replicates_dir = utils.path.get_unique_child_dir(eval_dir)
+ possible_nested_dir = utils.path.try_get_unique_child_dir(
+ eval_replicates_dir
+ )
+ if possible_nested_dir is not None:
+ eval_replicates_dir = possible_nested_dir
+ all_eval_replicates_dirs = (
+ utils.path.get_children_paths_wt_selecting_filter(
+ eval_replicates_dir, _filter="PG_"
+ )
+ )
+ return all_eval_replicates_dirs
+
+
+def _extract_checkpoints_used_for_each_players(
+ player_ids, eval_replicate_path
+):
+ params = utils.path.get_params_for_replicate(eval_replicate_path)
+ policies_config = params["multiagent"]["policies"]
+ ckps = [
+ policies_config[player_id][3]["checkpoint_to_load_from"][0]
+ for player_id in player_ids
+ ]
+ return ckps
+
+
+def _is_cross_play(players_ckpts):
+ return players_ckpts[0] != players_ckpts[1]
+
+
+def _convert_checkpoint_names_to_welfares(hp, players_ckpts):
+ players_welfares = []
+ for player_ckpt in players_ckpts:
+ player_ckpt_wtout_root = "/".join(player_ckpt.split("/")[-4:])
+ for welfare, ckpts_for_welfare in hp["ckpt_per_welfare"].items():
+ if any(
+ player_ckpt_wtout_root in ckpt for ckpt in ckpts_for_welfare
+ ):
+ players_welfares.append(welfare)
+ break
+
+ assert len(players_welfares) == len(
+ players_ckpts
+ ), f"{len(players_welfares)} == {len(players_ckpts)}"
+ return players_welfares
+
+
+def _extract_performance(eval_replicate_path, player_ids):
+ results_per_epi = utils.path.get_results_for_replicate(eval_replicate_path)
+ players_avg_reward = _extract_and_average_perf(results_per_epi, player_ids)
+ return players_avg_reward
+
+
+def _extract_and_average_perf(results_per_epi, player_ids):
+ players_avg_reward = []
+ for player_id in player_ids:
+ player_rewards = []
+ for result_in_one_epi in results_per_epi:
+ total_player_reward_in_one_epi = result_in_one_epi[
+ "policy_reward_mean"
+ ][player_id]
+ player_rewards.append(total_player_reward_in_one_epi)
+ players_avg_reward.append(sum(player_rewards) / len(player_rewards))
+ return players_avg_reward
+
+
+def _get_play_mode(players_welfares):
+ return f"{players_welfares[0]}-{players_welfares[1]}"
+
+
+def _average_perf_per_play_mode(raw_data_points_wt_welfares, hp):
+ all_welfare_pairs_wt_payoffs = {}
+ for (
+ play_mode,
+ values_per_replicates,
+ ) in raw_data_points_wt_welfares.items():
+ player_1_values = [
+ value_replicate[0] for value_replicate in values_per_replicates
+ ]
+ player_2_values = [
+ value_replicate[1] for value_replicate in values_per_replicates
+ ]
+ all_welfare_pairs_wt_payoffs[play_mode] = (
+ sum(player_1_values) / len(player_1_values) / hp["trace_length"],
+ sum(player_2_values) / len(player_2_values) / hp["trace_length"],
+ )
+ return all_welfare_pairs_wt_payoffs
+
+
+if __name__ == "__main__":
+ debug_mode = True
+ main(debug_mode)
diff --git a/marltoolbox/experiments/tune_class_api/lola_exact_official.py b/marltoolbox/experiments/tune_class_api/lola_exact_official.py
index 1122d30..c903377 100644
--- a/marltoolbox/experiments/tune_class_api/lola_exact_official.py
+++ b/marltoolbox/experiments/tune_class_api/lola_exact_official.py
@@ -9,9 +9,12 @@
import ray
from ray import tune
+from ray.tune.analysis import ExperimentAnalysis
from ray.rllib.agents.pg import PGTorchPolicy
+from ray.tune.integration.wandb import WandbLoggerCallback
+from marltoolbox.experiments.tune_class_api import lola_exact_meta_game
-from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExact
+from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExactTrainer
from marltoolbox.envs.matrix_sequential_social_dilemma import (
IteratedPrisonersDilemma,
IteratedMatchingPennies,
@@ -19,10 +22,31 @@
)
from marltoolbox.experiments.tune_class_api import lola_pg_official
from marltoolbox.utils import policy, log, miscellaneous
+from marltoolbox.scripts import aggregate_and_plot_tensorboard_data
def main(debug):
- train_n_replicates = 2 if debug else 40
+ hparams = get_hyperparameters(debug)
+
+ if hparams["load_plot_data"] is None:
+ ray.init(
+ num_cpus=os.cpu_count(),
+ num_gpus=0,
+ local_mode=debug,
+ )
+ tune_analysis_per_exp = train(hparams)
+ else:
+ tune_analysis_per_exp = None
+
+ evaluate(tune_analysis_per_exp, hparams)
+ ray.shutdown()
+
+
+def get_hyperparameters(debug, train_n_replicates=None, env=None):
+ """Get hyperparameters for LOLA-Exact for matrix games"""
+
+ if train_n_replicates is None:
+ train_n_replicates = 2 if debug else int(3 * 1)
seeds = miscellaneous.get_random_seeds(train_n_replicates)
exp_name, _ = log.log_in_current_day_dir("LOLA_Exact")
@@ -32,12 +56,24 @@ def main(debug):
"load_plot_data": None,
# Example "load_plot_data": ".../SelfAndCrossPlay_save.p",
"exp_name": exp_name,
+ "classify_into_welfare_fn": True,
"train_n_replicates": train_n_replicates,
- "env_name": "IPD",
- # "env_name": "IMP",
- # "env_name": "AsymBoS",
+ "wandb": {
+ "project": "LOLA_Exact",
+ "group": exp_name,
+ "api_key_file": os.path.join(
+ os.path.dirname(__file__), "../../../api_key_wandb"
+ ),
+ },
+ # "env_name": "IPD" if env is None else env,
+ # "env_name": "IMP" if env is None else env,
+ "env_name": "IteratedAsymBoS" if env is None else env,
"num_episodes": 5 if debug else 50,
"trace_length": 5 if debug else 200,
+ "re_init_every_n_epi": 1,
+ # "num_episodes": 5 if debug else 50 * 200,
+ # "trace_length": 1,
+ # "re_init_every_n_epi": 50,
"simple_net": True,
"corrections": True,
"pseudo": False,
@@ -53,46 +89,66 @@ def main(debug):
# "with_linear_LR_decay_to_zero": True,
# "clip_update": 0.1,
# "lr": 0.001,
+ "plot_keys": aggregate_and_plot_tensorboard_data.PLOT_KEYS + ["ret"],
+ "plot_assemblage_tags": aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS
+ + [("ret",)],
+ "x_limits": (-0.1, 4.1),
+ "y_limits": (-0.1, 4.1),
}
- if hparams["load_plot_data"] is None:
- ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
- tune_analysis_per_exp = train(hparams)
- else:
- tune_analysis_per_exp = None
-
- evaluate(tune_analysis_per_exp, hparams)
- ray.shutdown()
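+    # The multipliers below rescale plotted returns from per-episode sums to
+    # per-step values (one multiplier per plotted player axis).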
+ hparams["plot_axis_scale_multipliers"] = (
+ 1 / hparams["trace_length"],
+ 1 / hparams["trace_length"],
+ )
+ return hparams
def train(hp):
- tune_config, stop, _ = get_tune_config(hp)
+ tune_config, stop_config, _ = get_tune_config(hp)
# Train with the Tune Class API (not an RLLib Trainer)
tune_analysis = tune.run(
- LOLAExact,
+ LOLAExactTrainer,
name=hp["exp_name"],
config=tune_config,
checkpoint_at_end=True,
- stop=stop,
+ stop=stop_config,
metric=hp["metric"],
mode="max",
+ # callbacks=None
+ # if hp["debug"]
+ # else [
+ # WandbLoggerCallback(
+ # project=hp["wandb"]["project"],
+ # group=hp["wandb"]["group"],
+ # api_key_file=hp["wandb"]["api_key_file"],
+ # log_config=True,
+ # )
+ # ],
)
- tune_analysis_per_exp = {"": tune_analysis}
+ if hp["classify_into_welfare_fn"]:
+ tune_analysis_per_exp = _split_tune_results_wt_welfare(tune_analysis)
+ else:
+ tune_analysis_per_exp = {"": tune_analysis}
+
return tune_analysis_per_exp
-def get_tune_config(hp: dict) -> dict:
+def get_tune_config(hp: dict):
tune_config = copy.deepcopy(hp)
- assert tune_config["env_name"] in ("IPD", "IMP", "BoS", "AsymBoS")
+ assert tune_config["env_name"] in ("IPD", "IMP", "BoS", "IteratedAsymBoS")
+
+ env_config = {
+ "players_ids": ["player_row", "player_col"],
+ "max_steps": tune_config["trace_length"],
+ "get_additional_info": True,
+ }
- if tune_config["env_name"] in ("IPD", "IMP", "BoS", "AsymBoS"):
- env_config = {
- "players_ids": ["player_row", "player_col"],
- "max_steps": tune_config["trace_length"],
- "get_additional_info": True,
- }
+ if tune_config["env_name"] == "IteratedAsymBoS":
+ tune_config["Q_net_std"] = 3.0
+ else:
+ tune_config["Q_net_std"] = 1.0
- if tune_config["env_name"] in ("IPD", "BoS", "AsymBoS"):
+ if tune_config["env_name"] in ("IPD", "BoS", "IteratedAsymBoS"):
tune_config["gamma"] = (
0.96 if tune_config["gamma"] is None else tune_config["gamma"]
)
@@ -103,8 +159,8 @@ def get_tune_config(hp: dict) -> dict:
)
tune_config["save_dir"] = "dice_results_imp"
- stop = {"episodes_total": tune_config["num_episodes"]}
- return tune_config, stop, env_config
+ stop_config = {"episodes_total": tune_config["num_episodes"]}
+ return tune_config, stop_config, env_config
def evaluate(tune_analysis_per_exp, hp):
@@ -113,7 +169,7 @@ def evaluate(tune_analysis_per_exp, hp):
rllib_config_eval,
policies_to_load,
trainable_class,
- stop,
+ stop_config,
env_config,
) = generate_eval_config(hp)
@@ -122,9 +178,12 @@ def evaluate(tune_analysis_per_exp, hp):
rllib_config_eval,
policies_to_load,
trainable_class,
- stop,
+ stop_config,
env_config,
tune_analysis_per_exp,
+ n_cross_play_per_checkpoint=min(15, hp["train_n_replicates"] - 1)
+ if hp["classify_into_welfare_fn"]
+ else None,
)
@@ -136,8 +195,8 @@ def generate_eval_config(hp):
hp_eval["batch_size"] = 1
hp_eval["num_episodes"] = 100
- tune_config, stop, env_config = get_tune_config(hp_eval)
- tune_config["TuneTrainerClass"] = LOLAExact
+ tune_config, stop_config, env_config = get_tune_config(hp_eval)
+ tune_config["TuneTrainerClass"] = LOLAExactTrainer
hp_eval["group_names"] = ["lola"]
hp_eval["scale_multipliers"] = (
@@ -154,7 +213,7 @@ def generate_eval_config(hp):
hp_eval["env_class"] = IteratedMatchingPennies
hp_eval["x_limits"] = (-1.0, 1.0)
hp_eval["y_limits"] = (-1.0, 1.0)
- elif hp_eval["env_name"] == "AsymBoS":
+ elif hp_eval["env_name"] == "IteratedAsymBoS":
hp_eval["env_class"] = IteratedAsymBoS
hp_eval["x_limits"] = (-0.1, 4.1)
hp_eval["y_limits"] = (-0.1, 4.1)
@@ -184,21 +243,53 @@ def generate_eval_config(hp):
},
"seed": hp_eval["seed"],
"min_iter_time_s": hp_eval["min_iter_time_s"],
+ "num_workers": 0,
+ "num_envs_per_worker": 1,
}
policies_to_load = copy.deepcopy(env_config["players_ids"])
- trainable_class = LOLAExact
+ trainable_class = LOLAExactTrainer
return (
hp_eval,
rllib_config_eval,
policies_to_load,
trainable_class,
- stop,
+ stop_config,
env_config,
)
+def _split_tune_results_wt_welfare(
+ tune_analysis,
+):
+ tune_analysis_per_welfare = {}
+ for trial in tune_analysis.trials:
+ welfare_name = _get_trial_welfare(trial)
+ if welfare_name not in tune_analysis_per_welfare.keys():
+ _add_empty_tune_analysis(
+ tune_analysis_per_welfare, welfare_name, tune_analysis
+ )
+ tune_analysis_per_welfare[welfare_name].trials.append(trial)
+ return tune_analysis_per_welfare
+
+
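+# Illustrative usage sketch (hypothetical, not executed): each value returned
+# by _split_tune_results_wt_welfare is a deep copy of the original
+# ExperimentAnalysis that keeps only the trials of one welfare class, e.g.
+#   per_welfare = _split_tune_results_wt_welfare(tune_analysis)
+#   for welfare_name, analysis in per_welfare.items():
+#       print(welfare_name, len(analysis.trials))
+
+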
+def _get_trial_welfare(trial):
+ reward_player_1 = trial.last_result["ret1"]
+ reward_player_2 = trial.last_result["ret2"]
+ welfare_name = lola_exact_meta_game.classify_into_welfare_based_on_rewards(
+ reward_player_1, reward_player_2
+ )
+ return welfare_name
+
+
+def _add_empty_tune_analysis(
+ tune_analysis_per_welfare, welfare_name, tune_analysis
+):
+ tune_analysis_per_welfare[welfare_name] = copy.deepcopy(tune_analysis)
+ tune_analysis_per_welfare[welfare_name].trials = []
+
+
if __name__ == "__main__":
- debug_mode = True
+ debug_mode = False
main(debug_mode)
diff --git a/marltoolbox/experiments/tune_class_api/lola_pg_official.py b/marltoolbox/experiments/tune_class_api/lola_pg_official.py
index 080976b..1d2484c 100644
--- a/marltoolbox/experiments/tune_class_api/lola_pg_official.py
+++ b/marltoolbox/experiments/tune_class_api/lola_pg_official.py
@@ -11,25 +11,32 @@
##########
import copy
+import logging
import os
import time
import ray
from ray import tune
from ray.rllib.agents.dqn import DQNTorchPolicy
-from ray.tune.integration.wandb import WandbLogger
-from ray.tune.logger import DEFAULT_LOGGERS
+from ray.tune.integration.wandb import WandbLoggerCallback
-from marltoolbox.algos.lola import train_cg_tune_class_API
-from marltoolbox.algos.lola.train_pg_tune_class_API import LOLAPGMatrice
+from marltoolbox.algos.lola import (
+ train_cg_tune_class_API,
+ train_pg_tune_class_API,
+)
from marltoolbox.envs import (
vectorized_coin_game,
vectorized_mixed_motive_coin_game,
+ vectorized_ssd_mm_coin_game,
matrix_sequential_social_dilemma,
)
from marltoolbox.scripts import aggregate_and_plot_tensorboard_data
-from marltoolbox.utils import policy, log, self_and_cross_perf
+from marltoolbox import utils
+from marltoolbox.utils import policy, log, cross_play
from marltoolbox.utils.plot import PlotConfig
+from marltoolbox.experiments.tune_class_api import lola_exact_official
+
+logger = logging.getLogger(__name__)
def main(debug: bool, env=None):
@@ -40,68 +47,64 @@ def main(debug: bool, env=None):
:param debug: selection of debug mode using less compute
:param env: option to overwrite the env selection
"""
- train_n_replicates = 2 if debug else 1
+ train_n_replicates = 2 if debug else 10
timestamp = int(time.time())
seeds = [seed + timestamp for seed in list(range(train_n_replicates))]
exp_name, _ = log.log_in_current_day_dir("LOLA_PG")
+ tune_hparams = _get_hyperparameters(
+ debug, train_n_replicates, seeds, exp_name, env
+ )
+
+ if tune_hparams["load_plot_data"] is None:
+ ray.init(num_cpus=10, num_gpus=0, local_mode=debug)
+ tune_analysis_per_exp = _train(tune_hparams)
+ else:
+ tune_analysis_per_exp = None
+
+ _evaluate(tune_hparams, debug, tune_analysis_per_exp)
+ ray.shutdown()
+
+
+def _get_hyperparameters(debug, train_n_replicates, seeds, exp_name, env):
# The InfluenceEvader(like)
use_best_exploiter = False
- # use_best_exploiter = True
-
high_coop_speed_hp = True if use_best_exploiter else False
- # high_coop_speed_hp = True
+
+ gamma = 0.9
tune_hparams = {
"debug": debug,
"exp_name": exp_name,
"train_n_replicates": train_n_replicates,
- # wandb configuration
- "wandb": None
- if debug
- else {
+ "wandb": {
"project": "LOLA_PG",
"group": exp_name,
"api_key_file": os.path.join(
os.path.dirname(__file__), "../../../api_key_wandb"
),
- "log_config": True,
},
+ "classify_into_welfare_fn": True,
# Print metrics
"load_plot_data": None,
# Example: "load_plot_data": ".../SelfAndCrossPlay_save.p",
#
- # "gamma": 0.5,
- # "num_episodes": 3 if debug else 4000 if high_coop_speed_hp else 2000,
- # "trace_length": 4 if debug else 20,
- # "lr": None,
- #
- # "gamma": 0.875,
- # "lr": 0.005 / 4,
- # "num_episodes": 3 if debug else 4000,
- # "trace_length": 4 if debug else 20,
- #
- "gamma": 0.9375,
- "lr": 0.005 / 4
- if debug
- else tune.grid_search([0.005 / 4, 0.005 / 4 / 2, 0.005 / 4 / 2 / 2]),
- "num_episodes": 3 if debug else tune.grid_search([4000, 8000]),
- "trace_length": 4 if debug else tune.grid_search([40, 80]),
- #
- "batch_size": 8 if debug else 512,
# "env_name": "IteratedPrisonersDilemma" if env is None else env,
# "env_name": "IteratedAsymBoS" if env is None else env,
"env_name": "VectorizedCoinGame" if env is None else env,
# "env_name": "AsymVectorizedCoinGame" if env is None else env,
# "env_name": "VectorizedMixedMotiveCoinGame" if env is None else env,
+ # "env_name": "VectorizedSSDMixedMotiveCoinGame" if env is None else env,
+ "remove_trials_below_speed": False,
+ # "remove_trials_below_speed": 0.15,
"pseudo": False,
"grid_size": 3,
"lola_update": True,
"opp_model": False,
"mem_efficient": True,
"lr_correction": 1,
- "bs_mul": 1 / 10 * 3 if use_best_exploiter else 1 / 10,
+ "global_lr_divider": 1 / 10 * 3 if use_best_exploiter else 1 / 10,
"simple_net": True,
"hidden": 32,
"reg": 0,
@@ -117,20 +120,7 @@ def main(debug: bool, env=None):
"clip_loss_norm": False,
"clip_lola_update_norm": False,
"clip_lola_correction_norm": 3.0,
- # "clip_lola_correction_norm":
- # tune.grid_search([3.0 / 2, 3.0, 3.0 * 2]),
"clip_lola_actor_norm": 10.0,
- # "clip_lola_actor_norm": tune.grid_search([10.0 / 2, 10.0, 10.0 * 2]),
- "entropy_coeff": 0.001,
- # "entropy_coeff": tune.grid_search([0.001/2/2, 0.001/2, 0.001]),
- # "weigth_decay": 0.03,
- "weigth_decay": 0.03
- if debug
- else tune.grid_search([0.03 / 8 / 2 / 2, 0.03 / 8 / 2, 0.03 / 8]),
- # "lola_correction_multiplier": 1,
- "lola_correction_multiplier": 1
- if debug
- else tune.grid_search([1 * 4, 1 * 4 * 2, 1 * 4 * 2 * 2]),
"lr_decay": True,
"correction_reward_baseline_per_step": False,
"use_critic": False,
@@ -143,32 +133,119 @@ def main(debug: bool, env=None):
("total_reward",),
("entrop",),
],
+ "use_normalized_rewards": False,
+ "use_centered_reward": False,
+ "use_rolling_avg_actor_grad": False,
+ "process_reward_after_rolling": False,
+ "only_process_reward": False,
+ "use_rolling_avg_reward": False,
+ "reward_processing_bais": False,
+ "center_and_normalize_with_rolling_avg": False,
+ "punishment_helped": False,
}
- # Add exploiter hyperparameters
- tune_hparams.update(
- {
- "start_using_exploiter_at_update_n": 1
- if debug
- else 3000
- if high_coop_speed_hp
- else 1500,
- # PG exploiter
- "use_PG_exploiter": True if use_best_exploiter else False,
- "every_n_updates_copy_weights": 1 if debug else 100,
- # "adding_scaled_weights": False,
- # "adding_scaled_weights": 0.33,
- }
- )
-
- if tune_hparams["load_plot_data"] is None:
- ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
- tune_analysis_per_exp = _train(tune_hparams)
- else:
- tune_analysis_per_exp = None
+ if gamma == 0.5:
+ tune_hparams.update(
+ {
+ "gamma": 0.5,
+ "num_episodes": 3
+ if debug
+ else 4000
+ if high_coop_speed_hp
+ else 2000,
+ "trace_length": 10 if debug else 20,
+ "lr": None,
+ "weigth_decay": 0.03,
+ "lola_correction_multiplier": 1,
+ "entropy_coeff": 0.001,
+ "batch_size": 12 if debug else 512,
+ }
+ )
+ elif gamma == 0.875:
+ tune_hparams.update(
+ {
+ "gamma": 0.875,
+ "lr": 0.005 / 4,
+ "num_episodes": 3 if debug else 4000,
+ "trace_length": 10 if debug else 20,
+ "weigth_decay": 0.03 / 8,
+ "lola_correction_multiplier": 4,
+ "entropy_coeff": 0.001,
+ "batch_size": 12 if debug else 512,
+ }
+ )
+ elif gamma == 0.9375:
+ tune_hparams.update(
+ {
+ "gamma": 0.9375,
+ "lr": 0.005 / 4,
+ "num_episodes": 3 if debug else 2000,
+ "trace_length": 10 if debug else 40,
+ "weigth_decay": 0.03 / 32,
+ "lola_correction_multiplier": 4,
+ "entropy_coeff": 0.002,
+ "batch_size": 12 if debug else 1024,
+ }
+ )
+ elif gamma == 0.9:
+ tune_hparams.update(
+ {
+ "gamma": 0.9,
+ "lr": 0.005,
+ "num_episodes": 3 if debug else 2000,
+ "trace_length": 10 if debug else 40,
+ "weigth_decay": 0.03 / 16,
+ "lola_correction_multiplier": 8,
+ "entropy_coeff": 0.02,
+ "batch_size": 12 if debug else 1024,
+ "use_normalized_rewards": False,
+ "reward_processing_bais": 0.1,
+ "center_and_normalize_with_rolling_avg": False,
+ # "num_episodes": tune.grid_search(
+ # [
+ # 1000,
+ # 2000,
+ # 4000,
+ # ]
+ # ),
+ # "weigth_decay": tune.grid_search(
+ # [
+ # # 0.03 / 16,
+ # 0.03 / 16 * 10,
+ # # 0.03 / 16 * 100,
+ # ]
+ # ),
+ # "entropy_coeff": tune.grid_search(
+ # [
+ # 0.02 / 2,
+ # 0.02,
+ # 0.02 * 2,
+ # ]
+ # ),
+ # "use_normalized_rewards": tune.grid_search([False, True]),
+ # "center_and_normalize_with_rolling_avg": tune.grid_search(
+ # [False, True]
+ # ),
+ # "lola_correction_multiplier": tune.grid_search(
+ # [8.0 / 2.0, 8.0, 8.0 * 2.0]
+ # ),
+ }
+ )
- _evaluate(tune_hparams, debug, tune_analysis_per_exp)
- ray.shutdown()
+ if use_best_exploiter:
+ # Add exploiter hyperparameters
+ tune_hparams.update(
+ {
+ "start_using_exploiter_at_update_n": 1
+ if debug
+ else 3000
+ if high_coop_speed_hp
+ else 1500,
+ "use_PG_exploiter": True if use_best_exploiter else False,
+ "every_n_updates_copy_weights": 1 if debug else 100,
+ }
+ )
+ return tune_hparams
def _train(tune_hp):
@@ -177,7 +254,7 @@ def _train(tune_hp):
if "CoinGame" in tune_config["env_name"]:
trainable_class = train_cg_tune_class_API.LOLAPGCG
else:
- trainable_class = LOLAPGMatrice
+ trainable_class = train_pg_tune_class_API.LOLAPGMatrice
# Train with the Tune Class API (not RLLib Class)
tune_analysis = tune.run(
@@ -189,11 +266,28 @@ def _train(tune_hp):
metric=tune_config["metric"],
mode="max",
log_to_file=not tune_hp["debug"],
- loggers=DEFAULT_LOGGERS + (WandbLogger,),
+ callbacks=None
+ if tune_hp["debug"]
+ else [
+ WandbLoggerCallback(
+ project=tune_hp["wandb"]["project"],
+ group=tune_hp["wandb"]["group"],
+ api_key_file=tune_hp["wandb"]["api_key_file"],
+ log_config=True,
+ )
+ ],
)
- tune_analysis_per_exp = {"": tune_analysis}
- # if not tune_hp["debug"]:
+ if tune_hp["remove_trials_below_speed"]:
+ tune_analysis = _remove_failed_trials(tune_analysis, tune_hp)
+
+ if tune_hp["classify_into_welfare_fn"]:
+ tune_analysis_per_exp = _split_tune_results_wt_welfare(
+ tune_analysis, tune_hp
+ )
+ else:
+ tune_analysis_per_exp = {"": tune_analysis}
+
aggregate_and_plot_tensorboard_data.add_summary_plots(
main_path=os.path.join("~/ray_results/", tune_config["exp_name"]),
plot_keys=tune_config["plot_keys"],
@@ -203,6 +297,24 @@ def _train(tune_hp):
return tune_analysis_per_exp
+def _remove_failed_trials(results, tune_hp):
+ results = utils.tune_analysis.filter_trials(
+ results,
+        metric="player_red_pick_speed",
+ metric_threshold=tune_hp["remove_trials_below_speed"],
+ metric_mode="last-5-avg",
+ threshold_mode="above",
+ )
+ results = utils.tune_analysis.filter_trials(
+ results,
+        metric="player_blue_pick_speed",
+ metric_threshold=tune_hp["remove_trials_below_speed"],
+ metric_mode="last-5-avg",
+ threshold_mode="above",
+ )
+ return results
+
+
def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False):
tune_config = copy.deepcopy(tune_hp)
@@ -220,24 +332,13 @@ def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False):
tune_config[
"env_class"
] = vectorized_mixed_motive_coin_game.VectMixedMotiveCG
+ elif tune_config["env_name"] == "VectorizedSSDMixedMotiveCoinGame":
+ tune_config[
+ "env_class"
+ ] = vectorized_ssd_mm_coin_game.VectSSDMixedMotiveCG
else:
raise ValueError()
- tune_config["num_episodes"] = (
- 100000
- if tune_config["num_episodes"] is None
- else tune_config["num_episodes"]
- )
- tune_config["trace_length"] = (
- 150
- if tune_config["trace_length"] is None
- else tune_config["trace_length"]
- )
- tune_config["batch_size"] = (
- 4000
- if tune_config["batch_size"] is None
- else tune_config["batch_size"]
- )
tune_config["lr"] = (
0.005 if tune_config["lr"] is None else tune_config["lr"]
)
@@ -257,17 +358,30 @@ def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False):
):
tune_hp["x_limits"] = (-2.0, 4.0)
tune_hp["y_limits"] = (-2.0, 4.0)
- tune_hp["jitter"] = 0.02
+ elif (
+ tune_config["env_class"]
+ == vectorized_ssd_mm_coin_game.VectSSDMixedMotiveCG
+ ):
+ tune_hp["x_limits"] = (-0.1, 1.5)
+ tune_hp["y_limits"] = (-0.1, 1.5)
+
+ tune_hp["jitter"] = 0.00
env_config = {
"players_ids": ["player_red", "player_blue"],
- "batch_size": tune_config["batch_size"],
- "max_steps": tune_config["trace_length"],
+ "batch_size": tune.sample_from(
+ lambda spec: spec.config["batch_size"]
+ ),
+ "max_steps": tune.sample_from(
+ lambda spec: spec.config["trace_length"]
+ ),
"grid_size": tune_config["grid_size"],
"get_additional_info": True,
"both_players_can_pick_the_same_coin": tune_config["env_name"]
- == "VectorizedMixedMotiveCoinGame",
+ == "VectorizedMixedMotiveCoinGame"
+ or tune_config["env_name"] == "VectorizedSSDMixedMotiveCoinGame",
"force_vectorize": False,
"same_obs_for_each_player": True,
+ "punishment_helped": tune_config["punishment_helped"],
}
tune_config["metric"] = "player_blue_pick_speed"
tune_config["plot_keys"] += (
@@ -324,17 +438,13 @@ def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False):
}
tune_config["metric"] = "player_row_CC_freq"
- tune_hp["scale_multipliers"] = (
- 1 / tune_config["trace_length"],
- 1 / tune_config["trace_length"],
+ # For hyperparameter search
+ tune_hp["scale_multipliers"] = tune.sample_from(
+ lambda spec: (
+ 1 / spec.config["trace_length"],
+ 1 / spec.config["trace_length"],
+ )
)
- # For HP search
- # tune_hp["scale_multipliers"] = tune.sample_from(
- # lambda spec: (
- # 1 / spec.config["trace_length"],
- # 1 / spec.config["trace_length"],
- # )
- # )
tune_config["env_config"] = env_config
if stop_on_epi_number:
@@ -379,7 +489,17 @@ def _generate_eval_config(tune_hp, debug):
env_config["batch_size"] = 1
tune_config["TuneTrainerClass"] = train_cg_tune_class_API.LOLAPGCG
else:
- tune_config["TuneTrainerClass"] = LOLAPGMatrice
+ tune_config["TuneTrainerClass"] = train_pg_tune_class_API.LOLAPGMatrice
+ tune_config["env_config"].update(
+ {
+ "batch_size": env_config["batch_size"],
+ "max_steps": rllib_hp["trace_length"],
+ }
+ )
+ rllib_hp["scale_multipliers"] = (
+ 1 / rllib_hp["trace_length"],
+ 1 / rllib_hp["trace_length"],
+ )
rllib_config_eval = {
"env": rllib_hp["env_class"],
@@ -409,6 +529,8 @@ def _generate_eval_config(tune_hp, debug):
"callbacks": log.get_logging_callbacks_class(
log_full_epi=True,
),
+ "num_envs_per_worker": 1,
+ "num_workers": 0,
}
policies_to_load = copy.deepcopy(env_config["players_ids"])
@@ -421,7 +543,7 @@ def _generate_eval_config(tune_hp, debug):
"conv_filters": [[16, [3, 3], 1], [32, [3, 3], 1]],
}
else:
- trainable_class = LOLAPGMatrice
+ trainable_class = train_pg_tune_class_API.LOLAPGMatrice
return (
rllib_hp,
@@ -441,19 +563,22 @@ def _evaluate_self_and_cross_perf(
stop,
env_config,
tune_analysis_per_exp,
+ n_cross_play_per_checkpoint=None,
):
- evaluator = self_and_cross_perf.SelfAndCrossPlayEvaluator(
- exp_name=rllib_hp["exp_name"],
+ exp_name = os.path.join(rllib_hp["exp_name"], "eval")
+ evaluator = cross_play.evaluator.SelfAndCrossPlayEvaluator(
+ exp_name=exp_name,
local_mode=rllib_hp["debug"],
- use_wandb=not rllib_hp["debug"],
)
analysis_metrics_per_mode = evaluator.perform_evaluation_or_load_data(
evaluation_config=rllib_config_eval,
stop_config=stop,
policies_to_load_from_checkpoint=policies_to_load,
tune_analysis_per_exp=tune_analysis_per_exp,
- TuneTrainerClass=trainable_class,
- n_cross_play_per_checkpoint=min(5, rllib_hp["train_n_replicates"] - 1),
+ tune_trainer_class=trainable_class,
+ n_cross_play_per_checkpoint=min(5, rllib_hp["train_n_replicates"] - 1)
+ if n_cross_play_per_checkpoint is None
+ else n_cross_play_per_checkpoint,
to_load_path=rllib_hp["load_plot_data"],
)
@@ -461,7 +586,7 @@ def _evaluate_self_and_cross_perf(
rllib_hp["env_class"],
matrix_sequential_social_dilemma.MatrixSequentialSocialDilemma,
):
- background_area_coord = rllib_hp["env_class"].PAYOUT_MATRIX
+ background_area_coord = rllib_hp["env_class"].PAYOFF_MATRIX
else:
background_area_coord = None
@@ -485,6 +610,45 @@ def _evaluate_self_and_cross_perf(
)
+FAILURE = "failures"
+EGALITARIAN = "egalitarian"
+UTILITARIAN = "utilitarian"
+
+
+def _split_tune_results_wt_welfare(tune_analysis, hp):
+ tune_analysis_per_welfare = {}
+ for trial in tune_analysis.trials:
+ welfare_name = _get_trial_welfare(trial, hp)
+ if welfare_name not in tune_analysis_per_welfare.keys():
+ lola_exact_official._add_empty_tune_analysis(
+ tune_analysis_per_welfare, welfare_name, tune_analysis
+ )
+ tune_analysis_per_welfare[welfare_name].trials.append(trial)
+ return tune_analysis_per_welfare
+
+
+def _get_trial_welfare(trial, hp):
+ pick_own_player_1 = trial.last_result["player_red_pick_own_color"]
+ pick_own_player_2 = trial.last_result["player_blue_pick_own_color"]
+ welfare_name = lola_pg_classify_fn(
+ pick_own_player_1, pick_own_player_2, hp
+ )
+ return welfare_name
+
+
+def lola_pg_classify_fn(pick_own_player_1, pick_own_player_2, hp):
+ if hp["env_name"] == "VectorizedSSDMixedMotiveCoinGame":
+ if pick_own_player_1 < 0.75:
+ return UTILITARIAN
+ else:
+ return EGALITARIAN
+ else:
+ if pick_own_player_1 > 0.75:
+ return EGALITARIAN
+ else:
+ return UTILITARIAN
+
+
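+# Note on the classification above: both branches implement the same mapping
+# (own-color pick rate above 0.75 -> egalitarian, otherwise utilitarian) and
+# only differ on the exact 0.75 boundary. The 0.75 cutoff is a heuristic of
+# this experiment, not a documented constant.
+
+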
if __name__ == "__main__":
- debug_mode = True
+ debug_mode = False
main(debug_mode)
diff --git a/marltoolbox/experiments/tune_class_api/sos_exact_official.py b/marltoolbox/experiments/tune_class_api/sos_exact_official.py
new file mode 100644
index 0000000..4e512ba
--- /dev/null
+++ b/marltoolbox/experiments/tune_class_api/sos_exact_official.py
@@ -0,0 +1,263 @@
+##########
+# Additional dependencies are needed:
+# Follow the LOLA installation described in the
+# tune_class_api/lola_pg_official.py file
+##########
+
+import copy
+import os
+
+import ray
+from ray import tune
+from ray.rllib.agents.pg import PGTorchPolicy
+
+from marltoolbox.algos.sos import SOSTrainer
+from marltoolbox.envs.matrix_sequential_social_dilemma import (
+ IteratedPrisonersDilemma,
+ IteratedMatchingPennies,
+ IteratedAsymBoS,
+)
+from marltoolbox.experiments.tune_class_api import lola_exact_meta_game
+from marltoolbox.experiments.tune_class_api import lola_pg_official
+from marltoolbox.scripts import aggregate_and_plot_tensorboard_data
+from marltoolbox.utils import policy, log, miscellaneous
+
+
+def main(debug):
+ hparams = get_hyperparameters(debug)
+
+ if hparams["load_plot_data"] is None:
+ ray.init(
+ num_cpus=os.cpu_count(),
+ num_gpus=0,
+ local_mode=debug,
+ )
+ tune_analysis_per_exp = train(hparams)
+ else:
+ tune_analysis_per_exp = None
+
+ evaluate(tune_analysis_per_exp, hparams)
+ ray.shutdown()
+
+
+def get_hyperparameters(debug, train_n_replicates=None, env=None):
+    """Get hyperparameters for SOS-Exact for matrix games"""
+
+ if train_n_replicates is None:
+ train_n_replicates = 2 if debug else int(3 * 2)
+ seeds = miscellaneous.get_random_seeds(train_n_replicates)
+
+ exp_name, _ = log.log_in_current_day_dir("SOS")
+
+ hparams = {
+ "debug": debug,
+ "load_plot_data": None,
+ "exp_name": exp_name,
+ "classify_into_welfare_fn": True,
+ "train_n_replicates": train_n_replicates,
+ "wandb": {
+ "project": "SOS",
+ "group": exp_name,
+ "api_key_file": os.path.join(
+ os.path.dirname(__file__), "../../../api_key_wandb"
+ ),
+ },
+ "env_name": "IteratedAsymBoS" if env is None else env,
+ "lr": 1.0 / 10,
+ "gamma": 0.96,
+ "num_epochs": 5 if debug else 100,
+ # "method": "lola",
+ "method": "sos",
+ "inital_weights_std": 1.0,
+ "seed": tune.grid_search(seeds),
+ "metric": "mean_reward_player_row",
+ "plot_keys": aggregate_and_plot_tensorboard_data.PLOT_KEYS
+ + ["mean_reward"],
+ "plot_assemblage_tags": aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS
+ + [("mean_reward",)],
+ "x_limits": (-0.1, 4.1),
+ "y_limits": (-0.1, 4.1),
+ "max_steps_in_eval": 100,
+ }
+
+ return hparams
+
+
+def train(hp):
+ tune_config, stop_config, _ = get_tune_config(hp)
+ # Train with the Tune Class API (not an RLLib Trainer)
+ tune_analysis = tune.run(
+ SOSTrainer,
+ name=hp["exp_name"],
+ config=tune_config,
+ checkpoint_at_end=True,
+ stop=stop_config,
+ metric=hp["metric"],
+ mode="max",
+ )
+ if hp["classify_into_welfare_fn"]:
+ tune_analysis_per_exp = _split_tune_results_wt_welfare(tune_analysis)
+ else:
+ tune_analysis_per_exp = {"": tune_analysis}
+
+ aggregate_and_plot_tensorboard_data.add_summary_plots(
+ main_path=os.path.join("~/ray_results/", tune_config["exp_name"]),
+ plot_keys=tune_config["plot_keys"],
+ plot_assemble_tags_in_one_plot=tune_config["plot_assemblage_tags"],
+ )
+
+ return tune_analysis_per_exp
+
+
+def get_tune_config(hp: dict):
+ tune_config = copy.deepcopy(hp)
+ assert tune_config["env_name"] in ("IPD", "IteratedAsymBoS")
+ env_config = {
+ "players_ids": ["player_row", "player_col"],
+ "max_steps": hp["max_steps_in_eval"],
+ "get_additional_info": True,
+ }
+    tune_config["plot_axis_scale_multipliers"] = (
+        1 / hp["max_steps_in_eval"],
+        1 / hp["max_steps_in_eval"],
+    )
+ if "num_episodes" in tune_config:
+ stop_config = {"episodes_total": tune_config["num_episodes"]}
+ else:
+ stop_config = {"episodes_total": tune_config["num_epochs"]}
+
+ return tune_config, stop_config, env_config
+
+
+def evaluate(tune_analysis_per_exp, hp):
+ (
+ rllib_hp,
+ rllib_config_eval,
+ policies_to_load,
+ trainable_class,
+ stop_config,
+ env_config,
+ ) = generate_eval_config(hp)
+
+ lola_pg_official._evaluate_self_and_cross_perf(
+ rllib_hp,
+ rllib_config_eval,
+ policies_to_load,
+ trainable_class,
+ stop_config,
+ env_config,
+ tune_analysis_per_exp,
+ n_cross_play_per_checkpoint=min(15, hp["train_n_replicates"] - 1)
+ if hp["classify_into_welfare_fn"]
+ else None,
+ )
+
+
+def generate_eval_config(hp):
+ hp_eval = copy.deepcopy(hp)
+
+ hp_eval["min_iter_time_s"] = 3.0
+ hp_eval["seed"] = miscellaneous.get_random_seeds(1)[0]
+ hp_eval["batch_size"] = 1
+ hp_eval["num_episodes"] = 100
+
+ tune_config, stop_config, env_config = get_tune_config(hp_eval)
+ tune_config["TuneTrainerClass"] = SOSTrainer
+
+ hp_eval["group_names"] = ["lola"]
+ hp_eval["scale_multipliers"] = (
+ 1 / hp_eval["max_steps_in_eval"],
+ 1 / hp_eval["max_steps_in_eval"],
+ )
+ hp_eval["jitter"] = 0.05
+
+ if hp_eval["env_name"] == "IPD":
+ hp_eval["env_class"] = IteratedPrisonersDilemma
+ hp_eval["x_limits"] = (-3.5, 0.5)
+ hp_eval["y_limits"] = (-3.5, 0.5)
+ elif hp_eval["env_name"] == "IMP":
+ hp_eval["env_class"] = IteratedMatchingPennies
+ hp_eval["x_limits"] = (-1.0, 1.0)
+ hp_eval["y_limits"] = (-1.0, 1.0)
+ elif hp_eval["env_name"] == "IteratedAsymBoS":
+ hp_eval["env_class"] = IteratedAsymBoS
+ hp_eval["x_limits"] = (-0.1, 4.1)
+ hp_eval["y_limits"] = (-0.1, 4.1)
+ else:
+ raise NotImplementedError()
+
+ rllib_config_eval = {
+ "env": hp_eval["env_class"],
+ "env_config": env_config,
+ "multiagent": {
+ "policies": {
+ env_config["players_ids"][0]: (
+ policy.get_tune_policy_class(PGTorchPolicy),
+ hp_eval["env_class"](env_config).OBSERVATION_SPACE,
+ hp_eval["env_class"].ACTION_SPACE,
+ {"tune_config": copy.deepcopy(tune_config)},
+ ),
+ env_config["players_ids"][1]: (
+ policy.get_tune_policy_class(PGTorchPolicy),
+ hp_eval["env_class"](env_config).OBSERVATION_SPACE,
+ hp_eval["env_class"].ACTION_SPACE,
+ {"tune_config": copy.deepcopy(tune_config)},
+ ),
+ },
+ "policy_mapping_fn": lambda agent_id: agent_id,
+ "policies_to_train": ["None"],
+ },
+ "seed": hp_eval["seed"],
+ "min_iter_time_s": hp_eval["min_iter_time_s"],
+ "num_workers": 0,
+ "num_envs_per_worker": 1,
+ }
+
+ policies_to_load = copy.deepcopy(env_config["players_ids"])
+ trainable_class = SOSTrainer
+
+ return (
+ hp_eval,
+ rllib_config_eval,
+ policies_to_load,
+ trainable_class,
+ stop_config,
+ env_config,
+ )
+
+
+def _split_tune_results_wt_welfare(
+ tune_analysis,
+):
+ tune_analysis_per_welfare = {}
+ for trial in tune_analysis.trials:
+ welfare_name = _get_trial_welfare(trial)
+ if welfare_name not in tune_analysis_per_welfare.keys():
+ _add_empty_tune_analysis(
+ tune_analysis_per_welfare, welfare_name, tune_analysis
+ )
+ tune_analysis_per_welfare[welfare_name].trials.append(trial)
+ return tune_analysis_per_welfare
+
+
+def _get_trial_welfare(trial):
+ reward_player_1 = trial.last_result["mean_reward_player_row"]
+ reward_player_2 = trial.last_result["mean_reward_player_col"]
+ welfare_name = lola_exact_meta_game.classify_into_welfare_based_on_rewards(
+ reward_player_1, reward_player_2
+ )
+ return welfare_name
+
+
+def _add_empty_tune_analysis(
+ tune_analysis_per_welfare, welfare_name, tune_analysis
+):
+ tune_analysis_per_welfare[welfare_name] = copy.deepcopy(tune_analysis)
+ tune_analysis_per_welfare[welfare_name].trials = []
+
+
+if __name__ == "__main__":
+ debug_mode = False
+ main(debug_mode)
diff --git a/marltoolbox/experiments/tune_class_api/various_algo_meta_game.py b/marltoolbox/experiments/tune_class_api/various_algo_meta_game.py
new file mode 100644
index 0000000..f59e8e3
--- /dev/null
+++ b/marltoolbox/experiments/tune_class_api/various_algo_meta_game.py
@@ -0,0 +1,1030 @@
+import copy
+import os
+
+import numpy as np
+import pandas as pd
+import ray
+import torch
+from ray import tune
+from ray.rllib.agents import dqn
+from ray.rllib.agents.pg import PGTrainer
+from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy, pg_loss_stats
+
+from marltoolbox import utils
+from marltoolbox.algos import population, welfare_coordination
+from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExactTrainer
+from marltoolbox.algos.stochastic_population import StochasticPopulation
+from marltoolbox.algos.welfare_coordination import MetaGameSolver
+from marltoolbox.envs.matrix_sequential_social_dilemma import (
+ TwoPlayersCustomizableMatrixGame,
+)
+from marltoolbox.examples.rllib_api import pg_ipd
+from marltoolbox.experiments.rllib_api import amtft_meta_game
+from marltoolbox.experiments.rllib_api import amtft_various_env
+from marltoolbox.experiments.tune_class_api import (
+ lola_exact_meta_game,
+ lola_exact_official,
+)
+from marltoolbox.experiments.tune_class_api import sos_exact_official
+from marltoolbox.scripts import aggregate_and_plot_tensorboard_data
+from marltoolbox.utils import (
+ log,
+ miscellaneous,
+ callbacks,
+)
+
+
+def main(debug, base_game_algo=None, meta_game_algo=None):
+    """Evaluate meta-game performance"""
+
+ train_n_replicates = 1
+ seeds = miscellaneous.get_random_seeds(train_n_replicates)
+ exp_name, _ = log.log_in_current_day_dir("meta_game_compare")
+
+ ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
+ hparams = _get_hyperparameters(
+ debug, seeds, exp_name, base_game_algo, meta_game_algo
+ )
+ (
+ hparams["payoff_matrices"],
+ hparams["actions_possible"],
+ hparams["base_ckpt_per_replicat"],
+ ) = _form_n_matrices_from_base_game_payoffs(hparams)
+ hparams["meta_game_policy_distributions"] = _train_meta_policies(hparams)
+ tune_analysis, hp_eval = _evaluate_in_base_game(hparams)
+ ray.shutdown()
+
+ _extract_metric_and_log_and_plot(tune_analysis, hparams, hp_eval)
+
+
+def _extract_metric_and_log_and_plot(tune_analysis, hparams, hp_eval):
+ (
+ mean_player_1_payoffs,
+ mean_player_2_payoffs,
+ ) = _extract_metrics(tune_analysis, hparams)
+ results = []
+ for player1_avg_r_one_replicate, player2_avg_r_one_replicate in zip(
+ mean_player_1_payoffs, mean_player_2_payoffs
+ ):
+ results.append(
+ (player1_avg_r_one_replicate, player2_avg_r_one_replicate)
+ )
+ amtft_meta_game.save_to_json(exp_name=hparams["exp_name"], object=results)
+ amtft_meta_game.plot_results(
+ exp_name=hparams["exp_name"],
+ results=results,
+ hp_eval=hp_eval,
+ format_fn=_format_result_for_plotting,
+ jitter=0.05,
+ )
+
+
+BASE_AMTFT = "amTFT"
+BASE_LOLAExact = "base LOLA-Exact"
+META_LOLAExact = "meta LOLA-Exact"
+META_PG = "PG"
+META_SOS = "SOS"
+META_APLHA_RANK = "alpha-rank"
+META_APLHA_PURE = "alpha-rank pure strategy"
+META_REPLICATOR_DYNAMIC = "replicator dynamic"
+META_REPLICATOR_DYNAMIC_ZERO_INIT = "replicator dynamic with zero init"
+META_RANDOM = "Random"
+
+
+def _get_hyperparameters(
+ debug, seeds, exp_name, base_game_algo=None, meta_game_algo=None
+):
+ hp = {
+ # "base_game_policies": BASE_AMTFT,
+ "base_game_policies": BASE_LOLAExact,
+ #
+ # "meta_game_policies": META_PG,
+ # "meta_game_policies": META_LOLAExact,
+ # "meta_game_policies": META_APLHA_RANK,
+ # "meta_game_policies": META_APLHA_PURE,
+ # "meta_game_policies": META_REPLICATOR_DYNAMIC,
+ # "meta_game_policies": META_REPLICATOR_DYNAMIC_ZERO_INIT,
+ "meta_game_policies": META_RANDOM,
+ #
+ "apply_announcement_protocol": True,
+ #
+ "players_ids": ["player_row", "player_col"],
+ "use_r2d2": True,
+ }
+
+ if base_game_algo is not None:
+ hp["base_game_policies"] = base_game_algo
+ if meta_game_algo is not None:
+ hp["meta_game_policies"] = meta_game_algo
+
+ if hp["base_game_policies"] == BASE_AMTFT:
+ hp.update(
+ amtft_meta_game.get_hyperparameters(
+ debug=debug, use_r2d2=hp["use_r2d2"]
+ )
+ )
+ elif hp["base_game_policies"] == BASE_LOLAExact:
+ hp.update(lola_exact_meta_game.get_hyperparameters(debug=debug))
+ else:
+ raise ValueError()
+
+ hp.update(
+ {
+ "debug": debug,
+ "seeds": seeds,
+ "exp_name": exp_name,
+ "wandb": {
+ "project": "meta_game_compare",
+ "group": exp_name,
+ "api_key_file": os.path.join(
+ os.path.dirname(__file__), "../../../api_key_wandb"
+ ),
+ },
+ }
+ )
+
+ players_ids = ["player_row", "player_col"]
+ hp["x_axis_metric"] = f"policy_reward_mean.{players_ids[0]}"
+ hp["y_axis_metric"] = f"policy_reward_mean.{players_ids[1]}"
+ return hp
+
+
+payoffs_per_groups = None
+
+
+def _form_n_matrices_from_base_game_payoffs(hp):
+ global payoffs_per_groups
+ payoffs_per_groups = _get_payoffs_for_every_group_of_base_game_replicates(
+ hp
+ )
+    # Number of steps per episode used during the base-game evaluation
+ if hp["base_game_policies"] == BASE_AMTFT:
+ if hp["use_r2d2"]:
+            # I removed the switch to 100 steps in eval when I moved to R2D2
+ n_steps_per_epi = 20
+ else:
+ n_steps_per_epi = 100
+ elif hp["base_game_policies"] == BASE_LOLAExact:
+ if hp["debug"]:
+ n_steps_per_epi = 40
+ else:
+            n_steps_per_epi = 1  # in fact 200, but the payoffs are already averaged
+ else:
+ raise ValueError()
+
+ if hp["apply_announcement_protocol"]:
+ (
+ payoffs_matrices,
+ actions_possible,
+ base_ckpt_per_replicat,
+ ) = _aggregate_payoffs_groups_into_matrices_wt_announcement_protocol(
+ payoffs_per_groups, n_steps_per_epi, hp
+ )
+ else:
+ (
+ payoffs_matrices,
+ actions_possible,
+ base_ckpt_per_replicat,
+ ) = _aggregate_payoffs_groups_into_matrices(
+ payoffs_per_groups, n_steps_per_epi, hp
+ )
+
+ return payoffs_matrices, actions_possible, base_ckpt_per_replicat
+
+
+def _get_payoffs_for_every_group_of_base_game_replicates(hp):
+ if hp["base_game_policies"] == BASE_AMTFT:
+ module = amtft_meta_game
+ elif hp["base_game_policies"] == BASE_LOLAExact:
+ module = lola_exact_meta_game
+ else:
+ raise ValueError()
+
+ payoffs_per_groups = []
+ for i in range(hp["n_replicates_over_full_exp"]):
+ hp_replicate_i = module._load_base_game_results(
+ copy.deepcopy(hp), load_base_replicate_i=i
+ )
+
+ all_welfare_pairs_wt_payoffs = (
+ module._get_all_welfare_pairs_wt_cross_play_payoffs(
+ hp_replicate_i, hp_replicate_i["players_ids"]
+ )
+ )
+ payoffs_per_groups.append(
+ (all_welfare_pairs_wt_payoffs, hp_replicate_i)
+ )
+ return payoffs_per_groups
+
+
+def _aggregate_payoffs_groups_into_matrices_wt_announcement_protocol(
+ payoffs_per_groups, n_steps_per_epi, hp
+):
+ payoffs_matrices = []
+ ckpt_per_replicat = []
+ all_welfare_fn_sets = []
+ for i, (payoffs_for_one_group, hp_replicat_i) in enumerate(
+ payoffs_per_groups
+ ):
+ ckpt_per_replicat.append(hp_replicat_i["ckpt_per_welfare"])
+
+ announcement_protocol_solver_p1 = welfare_coordination.MetaGameSolver()
+ announcement_protocol_solver_p1.setup_meta_game(
+ payoffs_per_groups[i][0],
+ own_player_idx=0,
+ opp_player_idx=1,
+ own_default_welfare_fn="utilitarian",
+ opp_default_welfare_fn="inequity aversion"
+ if hp["base_game_policies"] == BASE_AMTFT
+ else "egalitarian",
+ )
+
+ welfare_fn_sets = announcement_protocol_solver_p1.welfare_fn_sets
+ n_set_of_welfare_sets = len(welfare_fn_sets)
+ payoff_matrix = np.empty(
+ shape=(n_set_of_welfare_sets, n_set_of_welfare_sets, 2),
+            dtype=float,
+ )
+ for own_welfare_set_idx, own_welfare_set_announced in enumerate(
+ welfare_fn_sets
+ ):
+            for opp_welfare_set_idx, opp_welfare_set in enumerate(
+ welfare_fn_sets
+ ):
+ cell_payoffs = (
+ announcement_protocol_solver_p1._compute_meta_payoff(
+                        own_welfare_set_announced, opp_welfare_set
+ )
+ )
+ payoff_matrix[own_welfare_set_idx, opp_welfare_set_idx, 0] = (
+ cell_payoffs[0] / n_steps_per_epi
+ )
+ payoff_matrix[own_welfare_set_idx, opp_welfare_set_idx, 1] = (
+ cell_payoffs[1] / n_steps_per_epi
+ )
+ amtft_meta_game.save_to_json(
+ exp_name=hp["exp_name"],
+ object=payoff_matrix.tolist(),
+ filename=f"payoffs_matrices_{i}.json",
+ )
+ payoffs_matrices.append(payoff_matrix)
+ all_welfare_fn_sets.append(tuple(welfare_fn_sets))
+ assert len(set(all_welfare_fn_sets)) == 1
+ return payoffs_matrices, welfare_fn_sets, ckpt_per_replicat
+
+
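+# Shape note (a reading of the loop above, not executed): with S announceable
+# welfare-function sets, each payoff_matrix is an S x S x 2 array where
+# [i, j, 0] and [i, j, 1] hold the per-step payoffs of player 1 and player 2
+# when they announce the welfare sets i and j respectively.
+
+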
+def _aggregate_payoffs_groups_into_matrices(
+ payoffs_per_groups, n_steps_per_epi, hp
+):
+ payoff_matrices = []
+ ckpt_per_replicat = []
+ all_welfares_fn = None
+ for i, (payoffs_for_one_group, hp_replicat_i) in enumerate(
+ payoffs_per_groups
+ ):
+ (
+ one_payoff_matrice,
+ tmp_all_welfares_fn,
+ ) = _aggregate_payoffs_in_one_matrix(
+ payoffs_for_one_group, n_steps_per_epi
+ )
+ amtft_meta_game.save_to_json(
+ exp_name=hp["exp_name"],
+ object=one_payoff_matrice.tolist(),
+ filename=f"payoffs_matrices_{i}.json",
+ )
+ payoff_matrices.append(one_payoff_matrice)
+ ckpt_per_replicat.append(hp_replicat_i["ckpt_per_welfare"])
+ if all_welfares_fn is None:
+ all_welfares_fn = tmp_all_welfares_fn
+ assert len(all_welfares_fn) == len(
+ tmp_all_welfares_fn
+ ), f"{len(all_welfares_fn)} == {len(tmp_all_welfares_fn)}"
+ return payoff_matrices, all_welfares_fn, ckpt_per_replicat
+
+
+def _aggregate_payoffs_in_one_matrix(payoffs_for_one_group, n_steps_per_epi):
+ all_welfares_fn = MetaGameSolver.list_all_welfares_fn(
+ payoffs_for_one_group
+ )
+ all_welfares_fn = sorted(tuple(all_welfares_fn))
+ n_welfare_fn = len(all_welfares_fn)
+ payoff_matrix = np.empty(
+        shape=(n_welfare_fn, n_welfare_fn, 2), dtype=float
+ )
+ for row_i, welfare_player_1 in enumerate(all_welfares_fn):
+ for col_i, welfare_player_2 in enumerate(all_welfares_fn):
+ welfare_pair_name = (
+ MetaGameSolver.from_pair_of_welfare_names_to_key(
+ welfare_player_1, welfare_player_2
+ )
+ )
+ payoff_matrix[row_i, col_i, 0] = (
+ payoffs_for_one_group[welfare_pair_name][0] / n_steps_per_epi
+ )
+ payoff_matrix[row_i, col_i, 1] = (
+ payoffs_for_one_group[welfare_pair_name][1] / n_steps_per_epi
+ )
+ return payoff_matrix, all_welfares_fn
+
+
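+# Minimal worked example (hypothetical numbers, not executed): with two
+# welfare functions, e.g. ["egalitarian", "utilitarian"], and
+# n_steps_per_epi = 1, the payoff dict of one group (keyed by
+# MetaGameSolver.from_pair_of_welfare_names_to_key) is turned into a
+# 2 x 2 x 2 matrix whose [row, col, player] entries are per-step payoffs.
+
+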
+def _train_meta_policies(hp):
+ hp["exp_name"] = os.path.join(hp["exp_name"], "meta_game")
+
+ if hp["meta_game_policies"] == META_PG:
+ meta_policies = _train_meta_policy_using_pg(hp)
+ elif hp["meta_game_policies"] == META_LOLAExact:
+ meta_policies = _train_meta_policy_using_lola_exact(hp)
+ elif hp["meta_game_policies"] == META_APLHA_RANK:
+ meta_policies = _train_meta_policy_using_alpha_rank(hp)
+ elif hp["meta_game_policies"] == META_APLHA_PURE:
+ meta_policies = _train_meta_policy_using_alpha_rank(
+ hp, pure_strategy=True
+ )
+ elif (
+ hp["meta_game_policies"] == META_REPLICATOR_DYNAMIC
+ or hp["meta_game_policies"] == META_REPLICATOR_DYNAMIC_ZERO_INIT
+ ):
+ meta_policies = _train_meta_policy_using_replicator_dynamic(hp)
+ elif hp["meta_game_policies"] == META_RANDOM:
+ meta_policies = _get_random_meta_policy(hp)
+ elif hp["meta_game_policies"] == META_SOS:
+ meta_policies = _train_meta_policy_using_sos_exact(hp)
+ else:
+ raise ValueError()
+
+ meta_policies = _clamp_policies_normalize(meta_policies)
+
+ return meta_policies
+
+
+def _clamp_policies_normalize(meta_policies):
+ for i in range(len(meta_policies)):
+ for player_key, player_meta_pi in meta_policies[i].items():
+ assert not (
+ any(player_meta_pi > 1.01) or any(player_meta_pi < -0.01)
+ ), f"player_meta_pi {player_meta_pi}"
+ player_meta_pi = player_meta_pi / player_meta_pi.sum()
+ meta_policies[i][player_key] = player_meta_pi.clamp(
+ min=0.0, max=1.0
+ )
+ return meta_policies
+
+
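+# Behaviour sketch (hypothetical input, not executed): a slightly
+# off-distribution meta policy such as torch.tensor([0.7, 0.31]) is first
+# renormalized to sum to 1 and then clamped into [0, 1], so downstream
+# sampling always receives a valid probability vector.
+
+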
+def _train_meta_policy_using_pg(hp):
+ rllib_config, stop_config = pg_ipd.get_rllib_config(
+ hp["seeds"], hp["debug"]
+ )
+ rllib_config, stop_config = _modify_rllib_config_for_meta_pg_policy(
+ rllib_config, stop_config, hp
+ )
+
+ tune_analysis = _train_with_tune(rllib_config, stop_config, hp, PGTrainer)
+
+ return _extract_policy_pg(tune_analysis)
+
+
+def _extract_policy_pg(tune_analysis):
+ policies = []
+ for trial in tune_analysis.trials:
+ next_act_distrib_idx = 0
+ p1_act_distrib = []
+ p2_act_distrib = []
+ p1_info = trial.last_result["info"]["learner"]["player_row"]
+ p2_info = trial.last_result["info"]["learner"]["player_col"]
+ prefix = "act_dist_inputs_single_act"
+ while True:
+ p1_act_distrib.append(p1_info[f"{prefix}{next_act_distrib_idx}"])
+ p2_act_distrib.append(p2_info[f"{prefix}{next_act_distrib_idx}"])
+ next_act_distrib_idx += 1
+ if f"{prefix}{next_act_distrib_idx}" not in p1_info.keys():
+ break
+ policy_player_1 = torch.softmax(torch.tensor(p1_act_distrib), dim=0)
+ policy_player_2 = torch.softmax(torch.tensor(p2_act_distrib), dim=0)
+ policies.append(
+ {"player_row": policy_player_1, "player_col": policy_player_2}
+ )
+ print("PG meta policy extracted")
+ print("policy_player_1 ", policy_player_1)
+ print("policy_player_2 ", policy_player_2)
+ return policies
+
+
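+# Note on the extraction above (assumption about the logged learner stats):
+# the "act_dist_inputs_single_act{i}" entries are read as per-action logits,
+# so the softmax over the collected vector recovers each player's meta policy
+# as a probability distribution over the meta-game actions.
+
+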
+def _modify_rllib_config_for_meta_pg_policy(rllib_config, stop_config, hp):
+ rllib_config["env"] = TwoPlayersCustomizableMatrixGame
+ rllib_config["env_config"]["NUM_ACTIONS"] = len(hp["actions_possible"])
+ rllib_config["env_config"]["all_welfares_fn"] = hp["actions_possible"]
+ rllib_config["env_config"]["max_steps"] = 1
+ rllib_config["model"] = {
+        # Sizes of the hidden layers of the fully connected net
+ "fcnet_hiddens": [64],
+ # Nonlinearity for fully connected net (tanh, relu)
+ "fcnet_activation": "relu",
+ }
+ rllib_config["lr"] = 0.003
+ stop_config["episodes_total"] = 10 if hp["debug"] else 8000
+
+ rllib_config["env_config"]["linked_data"] = _get_payoff_matrix_grid_search(
+ hp
+ )
+ rllib_config["seed"] = tune.sample_from(
+ lambda spec: spec.config["env_config"]["linked_data"][0]
+ )
+ rllib_config["env_config"]["PAYOFF_MATRIX"] = tune.sample_from(
+ lambda spec: spec.config["env_config"]["linked_data"][1]
+ )
+
+    rllib_config = _dynamically_change_policies_spaces(hp, rllib_config)
+
+ return rllib_config, stop_config
+
+
+def _dynamically_change_policies_spaces(hp, rllib_config):
+ MyPGTorchPolicy = PGTorchPolicy.with_updates(
+ stats_fn=log.augment_stats_fn_wt_additionnal_logs(
+ stats_function=pg_loss_stats
+ )
+ )
+
+ tmp_env_config = copy.deepcopy(rllib_config["env_config"])
+ tmp_env_config["PAYOFF_MATRIX"] = hp["payoff_matrices"][0]
+ tmp_env = rllib_config["env"](tmp_env_config)
+ for policy_id, policy_config in rllib_config["multiagent"][
+ "policies"
+ ].items():
+ policy_config = list(policy_config)
+ policy_config[0] = MyPGTorchPolicy
+ policy_config[1] = tmp_env.OBSERVATION_SPACE
+ policy_config[2] = tmp_env.ACTION_SPACE
+ rllib_config["multiagent"]["policies"][policy_id] = tuple(
+ policy_config
+ )
+ return rllib_config
+
+
+def _train_meta_policy_using_lola_exact(hp):
+ lola_exact_hp = lola_exact_meta_game.get_hyperparameters(hp["debug"])
+
+ tune_config, stop_config, _ = lola_exact_official.get_tune_config(
+ lola_exact_hp
+ )
+
+ tune_config, stop_config = _modify_tune_config_for_meta_lola_exact(
+ hp, tune_config, stop_config
+ )
+
+ tune_analysis = _train_with_tune(
+ tune_config, stop_config, hp, LOLAExactTrainer
+ )
+ return _extract_policy_lola_exact(tune_analysis)
+
+
+def _extract_policy_lola_exact(tune_analysis):
+ policies = []
+ for trial in tune_analysis.trials:
+ policy_player_1 = trial.last_result["policy1"][-1, :]
+ policy_player_2 = trial.last_result["policy2"][-1, :]
+ policy_player_1 = torch.tensor(policy_player_1)
+ policy_player_2 = torch.tensor(policy_player_2)
+ policies.append(
+ {"player_row": policy_player_1, "player_col": policy_player_2}
+ )
+ print("LOLA-Exact meta policy extracted")
+ print("policy_player_1 ", policy_player_1)
+ print("policy_player_2 ", policy_player_2)
+ return policies
+
+
+def _train_meta_policy_using_sos_exact(hp):
+ lola_exact_hp = sos_exact_official.get_hyperparameters(hp["debug"])
+
+ tune_config, stop_config, _ = sos_exact_official.get_tune_config(
+ lola_exact_hp
+ )
+
+ tune_config, stop_config = _modify_tune_config_for_meta_sos_exact(
+ hp, tune_config, stop_config
+ )
+
+ tune_analysis = _train_with_tune(
+ tune_config, stop_config, hp, LOLAExactTrainer
+ )
+ return _extract_policy_lola_exact(tune_analysis)
+
+
+def _modify_tune_config_for_meta_sos_exact(hp, tune_config, stop_config):
+ tune_config["env_name"] = None
+ tune_config["all_welfares_fn"] = hp["actions_possible"]
+ tune_config["method"] = "sos"
+ tune_config["linked_data"] = _get_payoff_matrix_grid_search(hp)
+ tune_config["seed"] = tune.sample_from(
+ lambda spec: spec.config["linked_data"][0]
+ )
+ tune_config["custom_payoff_matrix"] = tune.sample_from(
+ lambda spec: spec.config["linked_data"][1]
+ )
+
+ return tune_config, stop_config
+
+
+def _train_with_tune(
+ rllib_config,
+ stop_config,
+ hp,
+ trainer,
+ wandb=True,
+ plot_aggregates=True,
+):
+ tune_analysis = tune.run(
+ trainer,
+ config=rllib_config,
+ stop=stop_config,
+ name=hp["exp_name"],
+ # log_to_file=False if hp["debug"] else True,
+ # callbacks=None
+ # if hp["debug"] or not wandb
+ # else [
+ # WandbLoggerCallback(
+ # project=hp["wandb"]["project"],
+ # group=hp["wandb"]["group"],
+ # api_key_file=hp["wandb"]["api_key_file"],
+ # log_config=True,
+ # ),
+ # ],
+ )
+
+ if not hp["debug"] and plot_aggregates:
+ aggregate_and_plot_tensorboard_data.add_summary_plots(
+ main_path=os.path.join("~/ray_results/", hp["exp_name"]),
+ plot_keys=hp["plot_keys"],
+ plot_assemble_tags_in_one_plot=hp["plot_assemblage_tags"],
+ )
+ return tune_analysis
+
+
+def _modify_tune_config_for_meta_lola_exact(hp, tune_config, stop_config):
+ stop_config["episodes_total"] *= tune_config["trace_length"]
+ tune_config["re_init_every_n_epi"] *= tune_config["trace_length"]
+ tune_config["trace_length"] = 1
+ tune_config["env_name"] = "custom_payoff_matrix"
+ tune_config["all_welfares_fn"] = hp["actions_possible"]
+ tune_config["linked_data"] = _get_payoff_matrix_grid_search(hp)
+ tune_config["seed"] = tune.sample_from(
+ lambda spec: spec.config["linked_data"][0]
+ )
+ tune_config["custom_payoff_matrix"] = tune.sample_from(
+ lambda spec: spec.config["linked_data"][1]
+ )
+
+ return tune_config, stop_config
+
+
+def _get_payoff_matrix_grid_search(hp):
+ # payoff_matrices = [
+ # payoff_matrice_data for payoff_matrice_data in hp["payoff_matrices"]
+ # ]
+ payoff_matrices = copy.deepcopy(hp["payoff_matrices"])
+ seeds = miscellaneous.get_random_seeds(len(payoff_matrices))
+ linked_data = [
+ (seed, matrix) for seed, matrix in zip(seeds, payoff_matrices)
+ ]
+ return tune.grid_search(linked_data)
+
+
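+# How the linked grid search is consumed (sketch, relying only on the
+# tune.sample_from calls already used in this file): each grid-search entry
+# is a (seed, payoff_matrix) pair, unpacked elsewhere via e.g.
+#   seed = tune.sample_from(lambda spec: spec.config["linked_data"][0])
+#   matrix = tune.sample_from(lambda spec: spec.config["linked_data"][1])
+# so that each payoff matrix is evaluated with its own fresh seed.
+
+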
+def _evaluate_in_base_game(hp):
+ hp["exp_name"] = os.path.join(hp["exp_name"], "final_base_game")
+ all_rllib_configs = []
+ for meta_game_idx in range(hp["n_replicates_over_full_exp"]):
+ (
+ rllib_config,
+ stop_config,
+ trainer,
+ hp_eval,
+ ) = _get_final_base_game_rllib_config(copy.deepcopy(hp), meta_game_idx)
+
+ all_rllib_configs.append(rllib_config)
+
+ master_rllib_config = amtft_meta_game._mix_rllib_config(
+ all_rllib_configs, hp_eval=hp
+ )
+ tune_analysis = _train_with_tune(
+ master_rllib_config,
+ stop_config,
+ hp,
+ trainer,
+ wandb=False,
+ plot_aggregates=False,
+ )
+ return tune_analysis, hp_eval
+
+
+def _get_final_base_game_rllib_config(hp, meta_game_idx):
+ if hp["base_game_policies"] == BASE_AMTFT:
+ (
+ stop_config,
+ env_config,
+ rllib_config,
+ trainer,
+ hp_eval,
+ ) = _get_rllib_config_for_base_amTFT_policy(hp)
+ elif hp["base_game_policies"] == BASE_LOLAExact:
+ (
+ stop_config,
+ env_config,
+ rllib_config,
+ trainer,
+ hp_eval,
+ ) = _get_rllib_config_for_base_lola_exact_policy(hp)
+ else:
+ raise ValueError()
+
+ (
+ rllib_config,
+ stop_config,
+ ) = _change_simple_rllib_config_for_final_base_game_eval(
+ hp, rllib_config, stop_config
+ )
+
+ if hp["apply_announcement_protocol"]:
+ rllib_config = _change_rllib_config_to_use_welfare_coordination(
+ hp, rllib_config, meta_game_idx, env_config
+ )
+ else:
+ rllib_config = _change_rllib_config_to_use_stochastic_populations(
+ hp, rllib_config, meta_game_idx
+ )
+
+ return rllib_config, stop_config, trainer, hp_eval
+
+
+def _get_rllib_config_for_base_amTFT_policy(hp):
+ hp_eval = amtft_various_env.get_hyperparameters(
+ hp["debug"],
+ train_n_replicates=1,
+ env="IteratedAsymBoS",
+ use_r2d2=hp["use_r2d2"],
+ )
+ hp_eval = amtft_various_env.modify_hyperparams_for_the_selected_env(
+ hp_eval
+ )
+
+ (
+ rllib_config,
+ env_config,
+ stop_config,
+ hp_eval,
+ ) = amtft_various_env._generate_eval_config(hp_eval)
+
+ if hp["use_r2d2"]:
+ trainer = dqn.r2d2.R2D2Trainer
+ else:
+ trainer = dqn.dqn.DQNTrainer
+
+ return stop_config, env_config, rllib_config, trainer, hp_eval
+
+
+def _get_rllib_config_for_base_lola_exact_policy(hp):
+ lola_exact_hp = lola_exact_official.get_hyperparameters(
+ debug=hp["debug"], env="IteratedAsymBoS", train_n_replicates=1
+ )
+ (
+ hp_eval,
+ rllib_config,
+ policies_to_load,
+ trainable_class,
+ stop_config,
+ env_config,
+ ) = lola_exact_official.generate_eval_config(lola_exact_hp)
+
+ trainer = PGTrainer
+
+ return stop_config, env_config, rllib_config, trainer, lola_exact_hp
+
+
+def _change_simple_rllib_config_for_final_base_game_eval(
+ hp, rllib_config, stop_config
+):
+ rllib_config["multiagent"]["policies_to_train"] = ["None"]
+ rllib_config["callbacks"] = callbacks.merge_callbacks(
+ callbacks.PolicyCallbacks,
+ log.get_logging_callbacks_class(
+ log_full_epi=True,
+ log_full_epi_interval=1,
+ log_from_policy_in_evaluation=True,
+ ),
+ )
+ rllib_config["seed"] = tune.sample_from(
+ lambda spec: miscellaneous.get_random_seeds(1)[0]
+ )
+ if not hp["debug"]:
+ stop_config["episodes_total"] = 10
+ return rllib_config, stop_config
+
+
+def _change_rllib_config_to_use_welfare_coordination(
+ hp, rllib_config, meta_game_idx, env_config
+):
+ global payoffs_per_groups
+ all_welfare_pairs_wt_payoffs = payoffs_per_groups[meta_game_idx][0]
+
+ rllib_config["multiagent"]["policies_to_train"] = ["None"]
+ policies = rllib_config["multiagent"]["policies"]
+ for policy_idx, policy_id in enumerate(env_config["players_ids"]):
+ policy_config_items = list(policies[policy_id])
+ opp_policy_idx = (policy_idx + 1) % 2
+
+ egalitarian_welfare_name = (
+ "inequity aversion"
+ if hp["base_game_policies"] == BASE_AMTFT
+ else "egalitarian"
+ )
+ meta_policy_config = copy.deepcopy(welfare_coordination.DEFAULT_CONFIG)
+ meta_policy_config.update(
+ {
+ "nested_policies": [
+ {
+ "Policy_class": copy.deepcopy(policy_config_items[0]),
+ "config_update": copy.deepcopy(policy_config_items[3]),
+ },
+ ],
+ "all_welfare_pairs_wt_payoffs": all_welfare_pairs_wt_payoffs,
+ "solve_meta_game_after_init": False,
+ "own_player_idx": policy_idx,
+ "opp_player_idx": opp_policy_idx,
+ "own_default_welfare_fn": egalitarian_welfare_name
+ if policy_idx == 1
+ else "utilitarian",
+ "opp_default_welfare_fn": egalitarian_welfare_name
+ if opp_policy_idx == 1
+ else "utilitarian",
+ "policy_id_to_load": policy_id,
+ "policy_checkpoints": hp["base_ckpt_per_replicat"][
+ meta_game_idx
+ ],
+ "distrib_over_welfare_sets_to_annonce": hp[
+ "meta_game_policy_distributions"
+ ][meta_game_idx][policy_id],
+ }
+ )
+ policy_config_items[
+ 0
+ ] = welfare_coordination.WelfareCoordinationTorchPolicy
+ policy_config_items[3] = meta_policy_config
+ policies[policy_id] = tuple(policy_config_items)
+
+ return rllib_config
+
+
+def _change_rllib_config_to_use_stochastic_populations(
+ hp, rllib_config, meta_game_idx
+):
+ tmp_env = rllib_config["env"](rllib_config["env_config"])
+ policies = rllib_config["multiagent"]["policies"]
+ for policy_id, policy_config in policies.items():
+ policy_config = list(policy_config)
+
+ stochastic_population_policy_config = (
+ _create_one_stochastic_population_config(
+ hp, meta_game_idx, policy_id, policy_config
+ )
+ )
+
+ policy_config[0] = StochasticPopulation
+ policy_config[1] = tmp_env.OBSERVATION_SPACE
+ policy_config[2] = tmp_env.ACTION_SPACE
+ policy_config[3] = stochastic_population_policy_config
+
+ rllib_config["multiagent"]["policies"][policy_id] = tuple(
+ policy_config
+ )
+ return rllib_config
+
+
+def _create_one_stochastic_population_config(
+ hp, meta_game_idx, policy_id, policy_config
+):
+    """
+    This policy config is composed of 3 levels:
+    The top level: one stochastic population policy per player. This
+        policy stochastically selects (following some probability
+        distribution) which nested policy to use.
+    The intermediary (nested) level: one population (of identical policies)
+        per welfare function. This policy randomly selects which policy from
+        its population to use.
+    The bottom (base) level: the amTFT or LOLA-Exact policies used by the
+        intermediary level (amTFT itself contains another nested level).
+    """
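+    # Rough shape of the config assembled below (illustrative only, nothing
+    # here is executed):
+    #   {
+    #       "sampling_policy_distribution": <meta-policy distribution>,
+    #       "nested_policies": [
+    #           {"Policy_class": population.PopulationOfIdenticalAlgo,
+    #            "config_update": {"policy_checkpoints": [...],
+    #                              "nested_policies": [<base amTFT or
+    #                               LOLA-Exact policy>]}},
+    #           ...  # one entry per welfare in hp["actions_possible"]
+    #       ],
+    #   }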
+ stochastic_population_policy_config = {
+ "nested_policies": [],
+ "sampling_policy_distribution": hp["meta_game_policy_distributions"][
+ meta_game_idx
+ ][policy_id],
+ }
+
+ print('hp["base_ckpt_per_replicat"]', hp["base_ckpt_per_replicat"])
+ print('hp["actions_possible"]', hp["actions_possible"])
+ for welfare_i in hp["actions_possible"]:
+ one_nested_population_config = _create_one_vanilla_population_config(
+ hp,
+ policy_id,
+ copy.deepcopy(policy_config),
+ meta_game_idx,
+ welfare_i,
+ )
+
+ stochastic_population_policy_config["nested_policies"].append(
+ one_nested_population_config
+ )
+
+ return stochastic_population_policy_config
+
+
+def _create_one_vanilla_population_config(
+ hp,
+ policy_id,
+ policy_config,
+ meta_game_idx,
+ welfare_i,
+):
+ base_policy_class = copy.deepcopy(policy_config[0])
+ base_policy_config = copy.deepcopy(policy_config[3])
+
+ nested_population_config = copy.deepcopy(population.DEFAULT_CONFIG)
+ nested_population_config.update(
+ {
+ "policy_checkpoints": hp["base_ckpt_per_replicat"][meta_game_idx][
+ welfare_i
+ ],
+ "nested_policies": [
+ {
+ "Policy_class": base_policy_class,
+ "config_update": base_policy_config,
+ }
+ ],
+ "policy_id_to_load": policy_id,
+ }
+ )
+
+ intermediary_config = {
+ "Policy_class": population.PopulationOfIdenticalAlgo,
+ "config_update": nested_population_config,
+ }
+
+ return intermediary_config
+
+
+def _extract_metrics(tune_analysis, hp_eval):
+ player_1_payoffs = utils.tune_analysis.extract_value_from_last_training_iteration_for_each_trials(
+ tune_analysis, metric=hp_eval["x_axis_metric"]
+ )
+ player_2_payoffs = utils.tune_analysis.extract_value_from_last_training_iteration_for_each_trials(
+ tune_analysis, metric=hp_eval["y_axis_metric"]
+ )
+ print("player_1_payoffs", player_1_payoffs)
+ print("player_2_payoffs", player_2_payoffs)
+ return player_1_payoffs, player_2_payoffs
+
+
+def _format_result_for_plotting(results):
+ data_groups_per_mode = {}
+ df_rows = []
+ for player1_avg_r_one_replicate, player2_avg_r_one_replicate in results:
+ df_row_dict = {
+ "": (
+ player1_avg_r_one_replicate,
+ player2_avg_r_one_replicate,
+ )
+ }
+ df_rows.append(df_row_dict)
+ data_groups_per_mode["cross-play"] = pd.DataFrame(df_rows)
+ return data_groups_per_mode
+
+
+def _train_meta_policy_using_alpha_rank(hp, pure_strategy=False):
+ payoff_matrices = copy.deepcopy(hp["payoff_matrices"])
+
+ policies = []
+ for payoff_matrix in payoff_matrices:
+
+ payoff_tables_per_player = [
+ payoff_matrix[:, :, 0],
+ payoff_matrix[:, :, 1],
+ ]
+ policy_player_1, policy_player_2 = _compute_policy_wt_alpha_rank(
+ payoff_tables_per_player
+ )
+
+ if pure_strategy:
+ policy_player_1 = policy_player_1 == policy_player_1.max()
+ policy_player_2 = policy_player_2 == policy_player_2.max()
+ policy_player_1 = policy_player_1.float()
+ policy_player_2 = policy_player_2.float()
+
+ policies.append(
+ {"player_row": policy_player_1, "player_col": policy_player_2}
+ )
+ print("alpha rank meta policies", policies)
+ return policies
+
+
+def _compute_policy_wt_alpha_rank(payoff_tables_per_player):
+ from open_spiel.python.egt import alpharank
+ from open_spiel.python.algorithms.psro_v2 import utils as psro_v2_utils
+
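+    # Sweep over the alpha parameter to get the alpha-rank joint stationary
+    # distribution, then marginalize it into one mixed strategy per player.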
+ joint_arank, alpha = alpharank.sweep_pi_vs_alpha(
+ payoff_tables_per_player, return_alpha=True
+ )
+ print("alpha selected", alpha)
+ (
+ policy_player_1,
+ policy_player_2,
+ ) = psro_v2_utils.get_alpharank_marginals(
+ payoff_tables_per_player, joint_arank
+ )
+ print("policy_player_1", policy_player_1)
+ print("policy_player_2", policy_player_2)
+ policy_player_1 = torch.tensor(policy_player_1)
+ policy_player_2 = torch.tensor(policy_player_2)
+ return policy_player_1, policy_player_2
+
+
+def _train_meta_policy_using_replicator_dynamic(hp):
+ from open_spiel.python.algorithms.projected_replicator_dynamics import (
+ projected_replicator_dynamics,
+ )
+
+ payoff_matrices = copy.deepcopy(hp["payoff_matrices"])
+ policies = []
+ for payoff_matrix in payoff_matrices:
+ payoff_tables_per_player = [
+ payoff_matrix[:, :, 0],
+ payoff_matrix[:, :, 1],
+ ]
+ num_actions = payoff_matrix.shape[0]
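+        # Random initial mixed strategies, one per player, drawn from a
+        # symmetric Dirichlet(1.5) prior over the actions.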
+ prd_initial_strategies = [
+ np.random.dirichlet(np.ones(num_actions) * 1.5),
+ np.random.dirichlet(np.ones(num_actions) * 1.5),
+ ]
+ if hp["meta_game_policies"] == META_REPLICATOR_DYNAMIC_ZERO_INIT:
+ policy_player_1, policy_player_2 = projected_replicator_dynamics(
+ payoff_tables_per_player,
+ prd_gamma=0.0,
+ )
+ else:
+ print("prd_initial_strategies", prd_initial_strategies)
+ policy_player_1, policy_player_2 = projected_replicator_dynamics(
+ payoff_tables_per_player,
+ prd_gamma=0.0,
+ prd_initial_strategies=prd_initial_strategies,
+ )
+
+ policy_player_1 = torch.tensor(policy_player_1)
+ policy_player_2 = torch.tensor(policy_player_2)
+ policies.append(
+ {"player_row": policy_player_1, "player_col": policy_player_2}
+ )
+ print("replicator dynamic meta policies", policies)
+ return policies
+
+
+def _get_random_meta_policy(hp):
+ payoff_matrices = copy.deepcopy(hp["payoff_matrices"])
+ policies = []
+ for payoff_matrix in payoff_matrices:
+ num_actions_player_0 = payoff_matrix.shape[0]
+ num_actions_player_1 = payoff_matrix.shape[1]
+
+ policy_player_1 = (
+ torch.ones(size=(num_actions_player_0,)) / num_actions_player_0
+ )
+ policy_player_2 = (
+ torch.ones(size=(num_actions_player_1,)) / num_actions_player_1
+ )
+ policies.append(
+ {"player_row": policy_player_1, "player_col": policy_player_2}
+ )
+ print("random meta policies", policies)
+ return policies
+
+
+if __name__ == "__main__":
+ debug_mode = True
+ loop_over_main = True
+
+ if loop_over_main:
+ base_game_algo_to_eval = (BASE_LOLAExact,)
+ meta_game_algo_to_eval = (
+ # META_APLHA_RANK,
+ # META_APLHA_PURE,
+ # META_REPLICATOR_DYNAMIC,
+ # META_REPLICATOR_DYNAMIC_ZERO_INIT,
+ META_RANDOM,
+ META_PG,
+ META_LOLAExact,
+ META_SOS,
+ )
+ for base_game_algo in base_game_algo_to_eval:
+ for meta_game_algo in meta_game_algo_to_eval:
+ main(debug_mode, base_game_algo, meta_game_algo)
+ else:
+ main(debug_mode)
diff --git a/marltoolbox/scripts/aggregate_and_plot_tensorboard_data.py b/marltoolbox/scripts/aggregate_and_plot_tensorboard_data.py
index 440db1c..4acb0ba 100644
--- a/marltoolbox/scripts/aggregate_and_plot_tensorboard_data.py
+++ b/marltoolbox/scripts/aggregate_and_plot_tensorboard_data.py
@@ -14,45 +14,58 @@
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
-from tensorboard.backend.event_processing.event_accumulator import \
- EventAccumulator
-
-from marltoolbox.utils.miscellaneous import \
- list_all_files_in_one_dir_tree, ignore_str_containing_keys, \
- separate_str_in_group_containing_keys, GROUP_KEY_NONE, \
- keep_strs_containing_keys, fing_longer_substr
-from marltoolbox.utils.plot import \
- LOWER_ENVELOPE_SUFFIX, UPPER_ENVELOPE_SUFFIX, PlotHelper, PlotConfig
-
-FOLDER_NAME = 'aggregates'
-AGGREGATION_OPS = {"mean": np.mean,
- "min": np.min,
- "max": np.max,
- "median": np.median,
- "std": np.std,
- "var": np.var}
+from tensorboard.backend.event_processing.event_accumulator import (
+ EventAccumulator,
+)
+
+from marltoolbox.utils.miscellaneous import (
+ ignore_str_containing_keys,
+ separate_str_in_group_containing_keys,
+ GROUP_KEY_NONE,
+ keep_strs_containing_keys,
+ fing_longer_substr,
+)
+from marltoolbox.utils.path import list_all_files_in_one_dir_tree
+from marltoolbox.utils.plot import (
+ LOWER_ENVELOPE_SUFFIX,
+ UPPER_ENVELOPE_SUFFIX,
+ PlotHelper,
+ PlotConfig,
+)
+
+FOLDER_NAME = "aggregates"
+AGGREGATION_OPS = {
+ "mean": np.mean,
+ "min": np.min,
+ "max": np.max,
+ "median": np.median,
+ "std": np.std,
+ "var": np.var,
+}
COLORS = list(mcolors.TABLEAU_COLORS)
-PLOT_KEYS = ["grad_gnorm",
- "reward",
- "loss",
- "entropy",
- "entropy_avg",
- "td_error",
- "error",
- "act_dist_inputs_avg",
- "act_dist_inputs_single",
- "q_values_avg",
- "action_prob",
- "q_values_single",
- "_lr",
- "max_q_values",
- "min_q_values",
- "learn_on_batch",
- "timers",
- "ms",
- "throughput",
- ]
+PLOT_KEYS = [
+ "grad_gnorm",
+ "reward",
+ "loss",
+ "entropy",
+ "entropy_avg",
+ "td_error",
+ "error",
+ "act_dist_inputs_avg",
+ "act_dist_inputs_single",
+ "q_values_avg",
+ "action_prob",
+ "q_values_single",
+ "_lr",
+ "max_q_values",
+ "min_q_values",
+ "learn_on_batch",
+ "timers",
+ "ms",
+ "throughput",
+ "temperature",
+]
PLOT_ASSEMBLAGE_TAGS = [
("policy_reward_mean",),
@@ -82,22 +95,20 @@
("ms",),
("throughput",),
("_lr",),
+ ("temperature",),
]
-class TensorBoardDataExtractor():
-
+class TensorBoardDataExtractor:
def __init__(self, main_path):
self.main_path = main_path
now = datetime.datetime.now()
self.date_hour_str = now.strftime("%Y_%m_%d_%H_%M_%S")
- save_dir = os.path.join(self.main_path,
- FOLDER_NAME)
+ save_dir = os.path.join(self.main_path, FOLDER_NAME)
if not os.path.exists(save_dir):
os.mkdir(save_dir)
- save_dir = os.path.join(save_dir,
- self.date_hour_str)
+ save_dir = os.path.join(save_dir, self.date_hour_str)
if not os.path.exists(save_dir):
os.mkdir(save_dir)
@@ -106,10 +117,10 @@ def __init__(self, main_path):
def extract_data(self, ignore_keys, group_keys, output):
print("\n===== Extract data =====")
file_list = list_all_files_in_one_dir_tree(self.main_path)
- file_list_filtered = ignore_str_containing_keys(
- file_list, ignore_keys)
+ file_list_filtered = ignore_str_containing_keys(file_list, ignore_keys)
file_list_dict = separate_str_in_group_containing_keys(
- file_list_filtered, group_keys)
+ file_list_filtered, group_keys
+ )
self._aggregate(self.main_path, output, file_list_dict)
return self.save_dir
@@ -120,8 +131,7 @@ def _aggregate(self, main_path, output, file_list_dict):
extracts_per_group = {
group_key: self._extract_x_y_per_keys(main_path, file_list)
- for group_key, file_list in
- file_list_dict.items()
+ for group_key, file_list in file_list_dict.items()
}
if output == "summary":
@@ -130,25 +140,35 @@ def _aggregate(self, main_path, output, file_list_dict):
# https://github.com/Spenhouet/tensorboard-aggregator
elif output == "csv":
self._aggregate_to_csv(
- main_path, AGGREGATION_OPS, extracts_per_group)
+ main_path, AGGREGATION_OPS, extracts_per_group
+ )
print(f"End of aggregation {main_path}")
def _extract_x_y_per_keys(self, main_path, file_list):
- print("Going to extract", main_path,
- "with len(file_list)", len(file_list))
-
- event_readers = \
- self._create_event_reader_for_each_log_files(file_list)
+ print(
+ "Going to extract",
+ main_path,
+ "with len(file_list)",
+ len(file_list),
+ )
+
+ event_readers = self._create_event_reader_for_each_log_files(file_list)
if len(event_readers) == 0:
return None
- all_scalar_events_per_key, keys = \
- self._get_and_validate_all_scalar_keys(event_readers)
- steps_per_key, all_scalar_events_per_key = \
- self._get_and_validate_all_steps_per_key(
- all_scalar_events_per_key, keys)
- values_per_key = \
- self._get_values_per_step_per_key(all_scalar_events_per_key)
+ (
+ all_scalar_events_per_key,
+ keys,
+ ) = self._get_and_validate_all_scalar_keys(event_readers)
+ (
+ steps_per_key,
+ all_scalar_events_per_key,
+ ) = self._get_and_validate_all_steps_per_key(
+ all_scalar_events_per_key, keys
+ )
+ values_per_key = self._get_values_per_step_per_key(
+ all_scalar_events_per_key
+ )
keys = [key.replace("/", "_") for key in keys]
all_per_key = dict(zip(keys, zip(steps_per_key, values_per_key)))
@@ -156,27 +176,32 @@ def _extract_x_y_per_keys(self, main_path, file_list):
return all_per_key
def _create_event_reader_for_each_log_files(self, file_list):
- event_readers = [EventAccumulator(file_path).Reload(
- ).scalars for file_path in file_list]
+ event_readers = [
+ EventAccumulator(file_path).Reload().scalars
+ for file_path in file_list
+ ]
# Filter non event files
- event_readers = [one_event_reader for one_event_reader in event_readers
- if
- one_event_reader.Keys()]
+ event_readers = [
+ one_event_reader
+ for one_event_reader in event_readers
+ if one_event_reader.Keys()
+ ]
print(f"found {len(event_readers)} event_readers")
return event_readers
def _get_and_validate_all_scalar_keys(self, event_readers):
- all_keys = [tuple(one_event_reader.Keys())
- for one_event_reader in event_readers]
+ all_keys = [
+ tuple(one_event_reader.Keys())
+ for one_event_reader in event_readers
+ ]
self._print_discrepencies_in_keys(all_keys)
keys = self._get_common_keys(all_keys)
all_scalar_events_per_key = [
- [one_event_reader.Items(key)
- for one_event_reader in event_readers]
+ [one_event_reader.Items(key) for one_event_reader in event_readers]
for key in keys
]
return all_scalar_events_per_key, keys
@@ -189,8 +214,10 @@ def _print_discrepencies_in_keys(self, all_keys):
for k in keys_1:
if k not in keys_2:
if k not in missing_k_detected:
- print(f"key {k} is not present in all "
- f"event_readers")
+ print(
+ f"key {k} is not present in all "
+ f"event_readers"
+ )
missing_k_detected.append(k)
def _get_common_keys(self, all_keys):
@@ -207,21 +234,27 @@ def _get_common_keys(self, all_keys):
return common_keys
def _get_and_validate_all_steps_per_key(
- self, all_scalar_events_per_key, keys):
+ self, all_scalar_events_per_key, keys
+ ):
all_steps_per_key = [
- [tuple(scalar_event.step
- for scalar_event in scalar_events)
- for scalar_events in all_scalar_events]
- for all_scalar_events in all_scalar_events_per_key]
+ [
+ tuple(scalar_event.step for scalar_event in scalar_events)
+ for scalar_events in all_scalar_events
+ ]
+ for all_scalar_events in all_scalar_events_per_key
+ ]
steps_per_key = []
- for key_idx, (all_steps_for_one_key, key) in enumerate(zip(
- all_steps_per_key, keys)):
+ for key_idx, (all_steps_for_one_key, key) in enumerate(
+ zip(all_steps_per_key, keys)
+ ):
self._print_discrepencies_in_steps(all_steps_for_one_key, key)
common_steps = self._keep_common_steps(all_steps_for_one_key)
- all_scalar_events_per_key = \
+ all_scalar_events_per_key = (
self._remove_events_if_step_missing_somewhere(
- common_steps, all_scalar_events_per_key, key_idx)
+ common_steps, all_scalar_events_per_key, key_idx
+ )
+ )
steps_per_key.append(common_steps)
return steps_per_key, all_scalar_events_per_key
@@ -229,12 +262,14 @@ def _get_and_validate_all_steps_per_key(
def _print_discrepencies_in_steps(self, all_steps_for_one_key, key):
for steps_1 in all_steps_for_one_key:
for steps_2 in all_steps_for_one_key:
- missing_steps = [step
- for step in steps_1
- if step not in steps_2]
+ missing_steps = [
+ step for step in steps_1 if step not in steps_2
+ ]
if len(missing_steps) > 0:
- print(f"discrepency in steps logged for key {key}:"
- f"{missing_steps} missing")
+ print(
+ f"discrepency in steps logged for key {key}:"
+ f"{missing_steps} missing"
+ )
break
def _keep_common_steps(self, all_steps_for_one_key):
@@ -251,62 +286,76 @@ def _keep_common_steps(self, all_steps_for_one_key):
return common_steps
def _remove_events_if_step_missing_somewhere(
- self, common_steps, all_scalar_events_per_key, key_idx):
+ self, common_steps, all_scalar_events_per_key, key_idx
+ ):
all_scalar_events_per_key[key_idx] = [
- [scalar_event for scalar_event in scalar_events_batch
- if scalar_event.step in common_steps]
+ [
+ scalar_event
+ for scalar_event in scalar_events_batch
+ if scalar_event.step in common_steps
+ ]
for scalar_events_batch in all_scalar_events_per_key[key_idx]
]
return all_scalar_events_per_key
def _get_values_per_step_per_key(self, all_scalar_events_per_key):
values_per_key = [
- [[scalar_event.value
- for scalar_event in scalar_events]
- for scalar_events in all_scalar_events]
- for all_scalar_events in all_scalar_events_per_key]
+ [
+ [scalar_event.value for scalar_event in scalar_events]
+ for scalar_events in all_scalar_events
+ ]
+ for all_scalar_events in all_scalar_events_per_key
+ ]
return values_per_key
- def _aggregate_to_csv(self, main_path, aggregation_ops,
- extracts_per_group):
+ def _aggregate_to_csv(
+ self, main_path, aggregation_ops, extracts_per_group
+ ):
for group_key, all_per_key in extracts_per_group.items():
if all_per_key is None:
continue
for key, (steps, values) in all_per_key.items():
- aggregations = {key: aggregation_operation(values, axis=0)
- for key, aggregation_operation in
- aggregation_ops.items()}
+ aggregations = {
+                    op_name: aggregation_op(values, axis=0)
+                    for op_name, aggregation_op in aggregation_ops.items()
+ }
self._write_csv(main_path, group_key, key, aggregations, steps)
def _write_csv(self, main_path, group_key, key, aggregations, steps):
main_path_split = os.path.split(main_path)
group_dir = self._get_valid_filename(group_key)
- save_group_dir = os.path.join(self.save_dir, group_dir) \
- if group_key != GROUP_KEY_NONE else self.save_dir
+ save_group_dir = (
+ os.path.join(self.save_dir, group_dir)
+ if group_key != GROUP_KEY_NONE
+ else self.save_dir
+ )
if not os.path.exists(save_group_dir):
os.mkdir(save_group_dir)
- file_name = self._get_valid_filename(key) + '-' + \
- main_path_split[-1] + '.csv'
+ file_name = (
+ self._get_valid_filename(key) + "-" + main_path_split[-1] + ".csv"
+ )
df = pd.DataFrame(aggregations, index=steps)
save_dir_file_path = os.path.join(save_group_dir, file_name)
- df.to_csv(save_dir_file_path, sep=';')
+ df.to_csv(save_dir_file_path, sep=";")
def _get_valid_filename(self, s):
- s = str(s).strip().replace(' ', '_')
- return re.sub(r'(?u)[^-\w.]', '', s)
-
-
-class SummaryPlotter():
- def plot_selected_keys(self,
- save_dir,
- plot_keys,
- group_keys,
- plot_aggregates,
- plot_assemble_tags_in_one_plot,
- plot_single_lines,
- plot_labels_cleaning,
- additional_plot_config_kwargs):
+ s = str(s).strip().replace(" ", "_")
+ return re.sub(r"(?u)[^-\w.]", "", s)
+
+
+class SummaryPlotter:
+ def plot_selected_keys(
+ self,
+ save_dir,
+ plot_keys,
+ group_keys,
+ plot_aggregates,
+ plot_assemble_tags_in_one_plot,
+ plot_single_lines,
+ plot_labels_cleaning,
+ additional_plot_config_kwargs,
+ ):
self.plot_labels_cleaning = plot_labels_cleaning
self.plot_aggregates = plot_aggregates
@@ -316,15 +365,19 @@ def plot_selected_keys(self,
save_dir_path = save_dir
file_list = list_all_files_in_one_dir_tree(save_dir_path)
file_list = keep_strs_containing_keys(file_list, plot_keys)
- csv_file_list = [file_path
- for file_path in file_list
- if "csv" in file_path]
+ csv_file_list = [
+ file_path for file_path in file_list if "csv" in file_path
+ ]
csv_file_groups = separate_str_in_group_containing_keys(
- csv_file_list, group_keys)
+ csv_file_list, group_keys
+ )
for group_key, csv_files_in_one_group in csv_file_groups.items():
- save_dir_path_group = os.path.join(save_dir_path, group_key) \
- if group_key != GROUP_KEY_NONE else save_dir_path
+ save_dir_path_group = (
+ os.path.join(save_dir_path, group_key)
+ if group_key != GROUP_KEY_NONE
+ else save_dir_path
+ )
if not os.path.exists(save_dir_path_group):
os.mkdir(save_dir_path_group)
@@ -336,8 +389,10 @@ def plot_selected_keys(self,
print("===== Plot assemblages =====")
self.plot_several_lines_per_plot(
- save_dir_path_group, csv_files_in_one_group,
- plot_assemble_tags_in_one_plot)
+ save_dir_path_group,
+ csv_files_in_one_group,
+ plot_assemble_tags_in_one_plot,
+ )
def plot_one_graph(self, save_dir_path, csv_file_list, y_label=None):
data_groups = {}
@@ -349,33 +404,40 @@ def plot_one_graph(self, save_dir_path, csv_file_list, y_label=None):
if "min_max" in self.plot_aggregates:
assert "one_std" not in self.plot_aggregates
- df = df.rename(columns={'min': f'mean{LOWER_ENVELOPE_SUFFIX}',
- 'max': f'mean{UPPER_ENVELOPE_SUFFIX}'})
+ df = df.rename(
+ columns={
+ "min": f"mean{LOWER_ENVELOPE_SUFFIX}",
+ "max": f"mean{UPPER_ENVELOPE_SUFFIX}",
+ }
+ )
else:
- df = df.drop(columns=['min', 'max'])
+ df = df.drop(columns=["min", "max"])
if "one_std" in self.plot_aggregates:
assert "min_max" not in self.plot_aggregates
- df[f'mean{LOWER_ENVELOPE_SUFFIX}'] = df['mean'] - df['std']
- df[f'mean{UPPER_ENVELOPE_SUFFIX}'] = df['mean'] + df['std']
- df = df.drop(columns=['std', 'var', "median"])
+ df[f"mean{LOWER_ENVELOPE_SUFFIX}"] = df["mean"] - df["std"]
+ df[f"mean{UPPER_ENVELOPE_SUFFIX}"] = df["mean"] + df["std"]
+ df = df.drop(columns=["std", "var", "median"])
data_groups[tag] = df
plot_options = PlotConfig(
xlabel="steps",
ylabel=fing_longer_substr(all_tags_seen).strip("_")
- if y_label is None else y_label,
+ if y_label is None
+ else y_label,
save_dir_path=save_dir_path,
- **self.additional_plot_config_kwargs)
+ **self.additional_plot_config_kwargs,
+ )
plot_helper = PlotHelper(plot_options)
plot_helper.plot_lines(data_groups)
def plot_several_lines_per_plot(
- self, save_dir_path, csv_file_list,
- plot_assemble_tags_in_one_plot):
+ self, save_dir_path, csv_file_list, plot_assemble_tags_in_one_plot
+ ):
- for assemblage_idx, list_of_tags_in_assemblage in \
- enumerate(plot_assemble_tags_in_one_plot):
+ for assemblage_idx, list_of_tags_in_assemblage in enumerate(
+ plot_assemble_tags_in_one_plot
+ ):
assert isinstance(list_of_tags_in_assemblage, Iterable)
# select files for one assemblage
assemblage_list = self._group_csv_file_in_aggregates(
@@ -384,19 +446,25 @@ def plot_several_lines_per_plot(
if len(assemblage_list) > 0:
# plot one assemblage
- y_label = f"{assemblage_idx}_" + \
- " or ".join(list_of_tags_in_assemblage)
+ y_label = f"{assemblage_idx}_" + " or ".join(
+ list_of_tags_in_assemblage
+ )
self.plot_one_graph(
- save_dir_path, assemblage_list,
- y_label=y_label)
+ save_dir_path, assemblage_list, y_label=y_label
+ )
def _group_csv_file_in_aggregates(
- self, csv_file_list, list_of_tags_in_assemblage):
+ self, csv_file_list, list_of_tags_in_assemblage
+ ):
print(f"Start the {list_of_tags_in_assemblage} assemblage")
assemblage_list = []
for csv_file in csv_file_list:
- if any([select_key in csv_file
- for select_key in list_of_tags_in_assemblage]):
+ if any(
+ [
+ select_key in csv_file
+ for select_key in list_of_tags_in_assemblage
+ ]
+ ):
assemblage_list.append(csv_file)
# print("csv files selected for assemblage", assemblage_list)
assemblage_list = sorted(assemblage_list)
@@ -418,31 +486,36 @@ def extract_tag_from_file_name(self, csv_file):
return tag
-def add_summary_plots(main_path: str,
- ignore_keys: Iterable = (
- "aggregates", "same_cross_play"),
- group_keys: Iterable = (),
- output: str = "csv",
- plot_keys: Iterable = ("policy_reward_mean",
- "loss",
- "entropy",
- "entropy_avg",
- "td_error"),
- plot_aggregates: Iterable = ("mean", "min_max"),
- plot_assemble_tags_in_one_plot=(("policy_reward_mean",),
- ("loss", "td_error"),
- ("entropy",),
- ("entropy_avg",)),
- plot_single_lines=False,
- plot_labels_cleaning: Iterable = (
- ("learner_stats_", ""),
- ("info_learner_", ""),
- ("player_", "pl_")
- ),
- additional_plot_config_kwargs={
- "figsize": (8, 8),
- "legend_fontsize": "small"},
- ):
+def add_summary_plots(
+ main_path: str,
+ ignore_keys: Iterable = ("aggregates", "same_cross_play"),
+ group_keys: Iterable = (),
+ output: str = "csv",
+ plot_keys: Iterable = (
+ "policy_reward_mean",
+ "loss",
+ "entropy",
+ "entropy_avg",
+ "td_error",
+ ),
+ plot_aggregates: Iterable = ("mean", "min_max"),
+ plot_assemble_tags_in_one_plot=(
+ ("policy_reward_mean",),
+ ("loss", "td_error"),
+ ("entropy",),
+ ("entropy_avg",),
+ ),
+ plot_single_lines=False,
+ plot_labels_cleaning: Iterable = (
+ ("learner_stats_", ""),
+ ("info_learner_", ""),
+ ("player_", "pl_"),
+ ),
+ additional_plot_config_kwargs={
+ "figsize": (8, 8),
+ "legend_fontsize": "small",
+ },
+):
"""
Aggregates multiple tensorboard runs into mean, min, max, median, std and
save that in tensorboard files or in csv.
@@ -489,9 +562,11 @@ def add_summary_plots(main_path: str,
:return:
"""
- if output not in ['summary', 'csv']:
- raise ValueError("output must be one of ['summary', 'csv']"
- f"current output: {output}")
+ if output not in ["summary", "csv"]:
+ raise ValueError(
+ "output must be one of ['summary', 'csv']"
+ f"current output: {output}"
+ )
main_path = os.path.expanduser(main_path)
@@ -501,9 +576,15 @@ def add_summary_plots(main_path: str,
if output == "csv":
plotter = SummaryPlotter()
plotter.plot_selected_keys(
- save_dir, plot_keys, group_keys, plot_aggregates,
- plot_assemble_tags_in_one_plot, plot_single_lines,
- plot_labels_cleaning, additional_plot_config_kwargs)
+ save_dir,
+ plot_keys,
+ group_keys,
+ plot_aggregates,
+ plot_assemble_tags_in_one_plot,
+ plot_single_lines,
+ plot_labels_cleaning,
+ additional_plot_config_kwargs,
+ )
def param_list(param):
@@ -513,42 +594,61 @@ def param_list(param):
return p_list
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Args for aggregation data
- parser.add_argument("--main_path",
- type=str,
- help="main path for tensorboard files",
- default=os.getcwd())
- parser.add_argument("--group_keys",
- type=param_list,
- help="keys used to separate files in groups",
- default=[])
- parser.add_argument("--ignore_keys",
- type=param_list,
- help="keys used to ignore files",
- default=["aggregates", "same_cross_play"])
- parser.add_argument("--output",
- type=str,
- help="aggregation can be saves as "
- "tensorboard file (summary) or as table (csv)",
- default='csv')
+ parser.add_argument(
+ "--main_path",
+ type=str,
+ help="main path for tensorboard files",
+ default=os.getcwd(),
+ )
+ parser.add_argument(
+ "--group_keys",
+ type=param_list,
+ help="keys used to separate files in groups",
+ default=[],
+ )
+ parser.add_argument(
+ "--ignore_keys",
+ type=param_list,
+ help="keys used to ignore files",
+ default=["aggregates", "same_cross_play"],
+ )
+ parser.add_argument(
+ "--output",
+ type=str,
+ help="aggregation can be saves as "
+ "tensorboard file (summary) or as table (csv)",
+ default="csv",
+ )
# Args for plotting
- parser.add_argument("--plot_keys",
- type=param_list,
- help="keys used to select tensorboard tags to plot",
- default=['reward', 'loss', 'entropy'])
- parser.add_argument("--plot_aggregates",
- type=param_list,
- help="which results of aggregation operations to plot",
- default=['mean', 'min_max'])
- parser.add_argument("--plot_assemble_tags_in_one_plot",
- type=param_list,
- help="keys used to select tensorboard tags to "
- "aggregated plots",
- default=[['reward']])
+ parser.add_argument(
+ "--plot_keys",
+ type=param_list,
+ help="keys used to select tensorboard tags to plot",
+ default=["reward", "loss", "entropy"],
+ )
+ parser.add_argument(
+ "--plot_aggregates",
+ type=param_list,
+ help="which results of aggregation operations to plot",
+ default=["mean", "min_max"],
+ )
+ parser.add_argument(
+ "--plot_assemble_tags_in_one_plot",
+ type=param_list,
+ help="keys used to select tensorboard tags to " "aggregated plots",
+ default=[["reward"]],
+ )
args = parser.parse_args()
- add_summary_plots(args.main_path, args.ignore_keys, args.group_keys,
- args.output, args.plot_keys, args.plot_aggregates,
- args.plot_assemble_tags_in_one_plot)
+ add_summary_plots(
+ args.main_path,
+ args.ignore_keys,
+ args.group_keys,
+ args.output,
+ args.plot_keys,
+ args.plot_aggregates,
+ args.plot_assemble_tags_in_one_plot,
+ )
diff --git a/marltoolbox/scripts/average_saved_results.py b/marltoolbox/scripts/average_saved_results.py
new file mode 100644
index 0000000..4fd355f
--- /dev/null
+++ b/marltoolbox/scripts/average_saved_results.py
@@ -0,0 +1,135 @@
+import json
+import os
+
+import numpy as np
+
+
+def main(debug):
+ prefix, files_data, n_players = _get_inputs()
+ files_to_process = _preprocess_inputs(prefix, files_data)
+
+ for file, file_data in zip(files_to_process, files_data):
+ (
+ mean_per_player,
+ std_dev_per_player,
+ std_err_per_player,
+ ) = _get_stats_for_file(file, n_players, file_data)
+
+ print(
+ file_data[0],
+ "mean:",
+ mean_per_player,
+ "std_dev:",
+ std_dev_per_player,
+ "std_err:",
+ std_err_per_player,
+ )
+
+
+def _get_inputs():
+ prefix = "~/dev-maxime/CLR/vm-data/"
+ files_data = (
+ ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base
+ "META(PG) & BASE(LOLA-Exact)",
+ 200,
+ "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_07/13_46_27/final_base_game/final_eval_in_base_game.json",
+ ),
+ ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base
+ "META(LOLA-Exact) & BASE(LOLA-Exact)",
+ 200,
+ "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_07/12_32_57/final_base_game/final_eval_in_base_game.json",
+ ),
+ ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base
+ "META(alpha-rank pure strategies) & BASE(LOLA-Exact)",
+ 200,
+ "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_07/12_05_00/final_base_game/final_eval_in_base_game.json",
+ ),
+ ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base
+ "META(alpha-rank mixed strategies) & BASE(LOLA-Exact)",
+ 200,
+ "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_07/12_00_06/final_base_game/final_eval_in_base_game.json",
+ ),
+ ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base
+ "META(replicator dynamic) & BASE(LOLA-Exact)",
+ 200,
+ "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_07/11_24_50/final_base_game/final_eval_in_base_game.json",
+ ),
+ ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base
+ "META(Uniform(announcement+tau=0)) & BASE(announcement + LOLA-Exact)",
+ 200,
+ "instance-60-cpu-4-preemtible/LOLA_Exact/2021_05_07/15_47_51/final_eval_in_base_game.json",
+ "2nd_format_placeholder",
+ ),
+ ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base
+ "META(alpha-rank mixed on welfare sets) & BASE(announcement + "
+ "LOLA-Exact)",
+ 200,
+ "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_11/10_10_14/meta_game/final_base_game/final_eval_in_base_game.json",
+ ),
+ ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base
+ "META(alpha-rank pure on welfare sets) & BASE(announcement + "
+ "LOLA-Exact)",
+ 200,
+ "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_11/10_12_43/meta_game/final_base_game/final_eval_in_base_game.json",
+ ),
+ ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base
+ "META(replicator dynamic random init on welfare sets) & BASE("
+ "announcement + LOLA-Exact)",
+ 200,
+ "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_11/10_15_06/meta_game/final_base_game/final_eval_in_base_game.json",
+ ),
+ ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base
+ "META(replicator dynamic default init on welfare sets) & BASE("
+ "announcement + LOLA-Exact)",
+ 200,
+ "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_11/10_19_15/meta_game/final_base_game/final_eval_in_base_game.json",
+ ),
+ )
+ n_players = 2
+ return prefix, files_data, n_players
+
+
+def _preprocess_inputs(prefix, files_data):
+ files_to_process = [
+ os.path.join(prefix, file_data[2]) for file_data in files_data
+ ]
+ return files_to_process
+
+
+def _get_stats_for_file(file, n_players, file_data):
+ file_path = os.path.expanduser(file)
+    with open(file_path, "rb") as f:
+ file_content = json.load(f)
+ file_content = _format_2nd_into_1st_format(file_content, file_data)
+ values_per_replicat_per_player = np.array(file_content)
+
+ assert values_per_replicat_per_player.ndim == 2
+ n_replicates_in_content = values_per_replicat_per_player.shape[0]
+ n_players_in_content = values_per_replicat_per_player.shape[1]
+ assert n_players_in_content == n_players
+
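+    # file_data[1] is a divider used to rescale the raw returns
+    # (e.g. 200 in the inputs above).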
+ values_per_replicat_per_player = (
+ values_per_replicat_per_player / file_data[1]
+ )
+
+ mean_per_player = values_per_replicat_per_player.mean(axis=0)
+ std_dev_per_player = values_per_replicat_per_player.std(axis=0)
+ std_err_per_player = std_dev_per_player / np.sqrt(
+ n_replicates_in_content
+ )
+ return mean_per_player, std_dev_per_player, std_err_per_player
+
+
+def _format_2nd_into_1st_format(file_content, file_data):
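+    # The 2nd format (signalled by a 4-element file_data tuple) stores the
+    # payoffs as one list per player; re-zip them into a list of
+    # (player_1, player_2) tuples, i.e. the 1st format.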
+ if len(file_data) == 4:
+ file_content = file_content[0][2]
+ new_format = []
+ for p1_content, p2_content in zip(file_content[0], file_content[1]):
+ new_format.append((p1_content, p2_content))
+ file_content = new_format
+ return file_content
+
+
+if __name__ == "__main__":
+ debug_mode = False
+ main(debug_mode)
diff --git a/marltoolbox/scripts/plot_histogram_from_saved_results.py b/marltoolbox/scripts/plot_histogram_from_saved_results.py
new file mode 100644
index 0000000..241972a
--- /dev/null
+++ b/marltoolbox/scripts/plot_histogram_from_saved_results.py
@@ -0,0 +1,467 @@
+import json
+import os
+from collections import namedtuple
+
+import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
+import numpy as np
+
+plt.switch_backend("agg")
+plt.style.use("seaborn-whitegrid")
+plt.rcParams.update({"font.size": 12})
+
+COLORS = list(mcolors.TABLEAU_COLORS) + list(mcolors.XKCD_COLORS)
+RANDOM_MARKERS = ["1", "2", "3", "4", "8", "s", "p", "P", "*", "h", "H", "+"]
+MARKERS = ["o", "s", "v", "^", "<", ">", "P", "X", "D", "*"] + RANDOM_MARKERS
+
+Exp_data = namedtuple("Exp_data", ["base_algo", "env", "perf"])
+Perf = namedtuple("Perf", ["mean", "std_dev", "std_err", "raw"])
+File_data = namedtuple(
+ "File_data",
+ [
+ "base_algo",
+ "env",
+ "reward_adaptation_divider",
+ "path_to_self_play",
+ "path_to_preferences",
+ "max_r_by_players",
+ "min_r_by_players",
+ ],
+)
+
+NA = "N/A"
+
+PLAYER_0 = 0
+PLAYER_1 = 1
+
+
+def main(debug):
+ prefix, files_data, n_players = _get_inputs()
+ files_to_process = _preprocess_inputs(prefix, files_data)
+
+ perf_per_mode_per_files = []
+ for file_paths, file_data in zip(files_to_process, files_data):
+ perf_per_mode = _get_stats(file_paths, n_players, file_data)
+ perf_per_mode_per_files.append(
+ Exp_data(file_data.base_algo, file_data.env, perf_per_mode)
+ )
+
+ _plot_bars(perf_per_mode_per_files)
+
+
+IPD_MAX = (-1, -1)
+IPD_MIN = (-3, -3)
+ASYMIBOS_MAX = ((4 + 2) / 2.0, (2 + 1) / 2.0)
+ASYMIBOS_MIN = (0, 0)
+N_CELLS_AT_1_STEP = 4
+N_CELLS_AT_2_STEPS = 4
+N_CELLS_EXCLUDING_CURRENT = 8
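+# Maximum expected picking speed, assuming the coin spawns uniformly over the
+# 8 cells other than the agent's: 4/8 of the time it is 1 step away and 4/8
+# of the time 2 steps away, i.e. 4/8/1 + 4/8/2 = 0.75 picks per step.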
+MAX_PICK_SPEED = (
+ N_CELLS_AT_1_STEP / N_CELLS_EXCLUDING_CURRENT / 1
+ + N_CELLS_AT_2_STEPS / N_CELLS_EXCLUDING_CURRENT / 2
+) # 0.75
+CG_MAX = (1.0 * MAX_PICK_SPEED / 2.0, 1.0 * MAX_PICK_SPEED / 2.0)
+CG_MIN = (0, 0)
+MCPCG_MAX = (
+ (2 / 2.0 + 1 / 2.0) * MAX_PICK_SPEED / 2.0,
+ (3 / 2.0 + 1 / 2.0) * MAX_PICK_SPEED / 2.0,
+)
+MCPCG_MIN = (0, 0)
+
+LOLA_EXACT_WT_IPD_IDX = 1
+N_NO_MCP = 4
+
+
+def _get_inputs():
+ prefix = "~/dev-maxime/CLR/vm-data/"
+ files_data = (
+ File_data(
+ "amTFT",
+ "IPD",
+ 20.0,
+ "instance-60-cpu-1"
+ "-preemtible/amTFT/2021_05_11/07_31_41/eval/2021_05_11/09_20_44"
+ "/plot_self_crossself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json",
+ "instance-60-cpu-1"
+ "-preemtible/amTFT/2021_05_11/07_31_41/eval/2021_05_11/09_20_44"
+ "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json",
+ IPD_MAX,
+ IPD_MIN,
+ ),
+ File_data(
+ "LOLA-Exact",
+ "IPD",
+ 200.0,
+ "instance-60-cpu-1"
+ "-preemtible/LOLA_Exact/2021_05_11/07_46_03/eval/2021_05_11"
+ "/07_49_14"
+ "/plot_self_crossself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json",
+ None,
+ # "instance-60-cpu-1"
+ # "-preemtible/LOLA_Exact/2021_05_11/07_46_03/eval/2021_05_11/07_49_14"
+ # "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json",
+ IPD_MAX,
+ IPD_MIN,
+ ),
+ File_data(
+ "amTFT",
+ "CG",
+ 100.0,
+ "instance-60-cpu-3-preemtible"
+ "/amTFT/2021_05_11/08_02_22/eval/2021_05_11/17_32_31"
+ "/plot_self_crossself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json",
+ "instance-60-cpu-3-preemtible"
+ "/amTFT/2021_05_11/08_02_22/eval/2021_05_11/17_32_31"
+ "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json",
+ CG_MAX,
+ CG_MIN,
+ ),
+ File_data(
+ "LOLA-PG",
+ "CG",
+ 40.0,
+ "instance-10-cpu-8-memory-x2"
+ "/LOLA_PG/2021_05_12/10_09_50/eval/2021_05_12/21_38_50"
+ "/plot_self_crossself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json",
+ None,
+ # "instance-10-cpu-8-memory-x2"
+ # "/LOLA_PG/2021_05_12/10_09_50/eval/2021_05_12/21_38_50"
+ # "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json",
+ CG_MAX,
+ CG_MIN,
+ ),
+ File_data(
+ "amTFT",
+ "AsymIBoS",
+ 20.0,
+ "instance-60-cpu-1"
+ "-preemtible/amTFT/2021_05_11/07_40_04/eval/2021_05_11/11_43_26"
+ "/plot_self_crossself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json",
+ "instance-60-cpu-1"
+ "-preemtible/amTFT/2021_05_11/07_40_04/eval/2021_05_11/11_43_26"
+ "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json",
+ ASYMIBOS_MAX,
+ ASYMIBOS_MIN,
+ ),
+ File_data(
+ "LOLA-Exact",
+ "AsymIBoS",
+ 200.0,
+ "instance-60-cpu-1"
+ "-preemtible/LOLA_Exact/2021_05_11/07_47_16/eval/2021_05_11"
+ "/07_50_36"
+ "/plot_self_crossself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json",
+ "instance-60-cpu-1"
+ "-preemtible/LOLA_Exact/2021_05_11/07_47_16/eval/2021_05_11/07_50_36"
+ "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json",
+ ASYMIBOS_MAX,
+ ASYMIBOS_MIN,
+ ),
+ File_data(
+ "amTFT",
+ "MCPCG",
+ 100.0,
+ "instance-10-cpu-7-memory-x2"
+ "/amTFT/2021_05_09/08_20_23/eval/2021_05_09/23_21_35"
+ "/plot_self_crossself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json",
+ "instance-10-cpu-7-memory-x2"
+ "/amTFT/2021_05_09/08_20_23/eval/2021_05_09/23_21_35"
+ "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json",
+ MCPCG_MAX,
+ MCPCG_MIN,
+ ),
+ File_data(
+ "LOLA-PG",
+ "MCPCG",
+ 40.0,
+ "instance-60-cpu-3-preemtible"
+ "/LOLA_PG/2021_05_09/15_54_03/eval/2021_05_09/23_12_34"
+ "/plot_self_crossself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json",
+ "instance-60-cpu-3-preemtible"
+ "/LOLA_PG/2021_05_09/15_54_03/eval/2021_05_09/23_12_34"
+ "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json",
+ MCPCG_MAX,
+ MCPCG_MIN,
+ ),
+ )
+ n_players = 2
+ return prefix, files_data, n_players
+
+
+def _preprocess_inputs(prefix, files_data):
+ files_to_process = [
+ (
+ os.path.join(prefix, file_data.path_to_self_play),
+ os.path.join(prefix, file_data.path_to_preferences)
+ if file_data.path_to_preferences is not None
+ else None,
+ )
+ for file_data in files_data
+ ]
+ return files_to_process
+
+
+def _get_stats(file_paths, n_players, file_data):
+ self_play_path = file_paths[0]
+ perf_per_mode = _get_stats_for_file(self_play_path, n_players, file_data)
+ print("perf_per_mode", perf_per_mode)
+ self_play = perf_per_mode["self-play"]
+ cross_play = perf_per_mode["cross-play"]
+
+ preference_path = file_paths[1]
+ if preference_path is not None:
+ perf_per_mode_bis = _get_stats_for_file(
+ preference_path, n_players, file_data
+ )
+ same_preferences_cross_play = perf_per_mode_bis[
+ "cross-play: same pref vs same pref"
+ ]
+ if "cross-play: diff pref vs diff pref" in perf_per_mode_bis.keys():
+ diff_preferences_cross_play = perf_per_mode_bis[
+ "cross-play: diff pref vs diff pref"
+ ]
+ else:
+ diff_preferences_cross_play = NA
+ else:
+ same_preferences_cross_play = NA
+ diff_preferences_cross_play = NA
+
+ all_perf = [
+ self_play,
+ cross_play,
+ same_preferences_cross_play,
+ diff_preferences_cross_play,
+ ]
+ return all_perf
+
+
+def _get_stats_for_file(file, n_players, file_data):
+ perf_per_mode = {}
+ file_path = os.path.expanduser(file)
+    with open(file_path, "rb") as f:
+ file_content = json.load(f)
+ for eval_mode, mode_perf in file_content.items():
+        perf = [None] * n_players
+ print("eval_mode", eval_mode)
+ for metric, metric_perf in mode_perf.items():
+ player_idx = _extract_player_idx(metric)
+
+ perf_per_replicat = np.array(
+ _convert_str_of_list_to_list(metric_perf["raw_data"])
+ )
+
+ n_replicates_in_content = len(perf_per_replicat)
+ values_per_replicat_per_player = _adapt_values(
+ perf_per_replicat, file_data, player_idx
+ )
+
+ print(
+ "values_per_replicat_per_player",
+ values_per_replicat_per_player,
+ )
+ mean_per_player = values_per_replicat_per_player.mean(axis=0)
+ std_dev_per_player = values_per_replicat_per_player.std(axis=0)
+ std_err_per_player = std_dev_per_player / np.sqrt(
+ n_replicates_in_content
+ )
+ perf[player_idx] = Perf(
+ mean_per_player,
+ std_dev_per_player,
+ std_err_per_player,
+ values_per_replicat_per_player,
+ )
+ perf_per_mode[eval_mode] = perf
+
+ return perf_per_mode
+
+
+def _extract_player_idx(metric):
+ if "player_row" in metric:
+ player_idx = PLAYER_0
+ elif "player_col" in metric:
+ player_idx = PLAYER_1
+ elif "player_red" in metric:
+ player_idx = PLAYER_0
+ elif "player_blue" in metric:
+ player_idx = PLAYER_1
+ else:
+        raise ValueError(f"Unknown player in metric: {metric}")
+ return player_idx
+
+
+def _adapt_values(values_per_replicat_per_player, file_data, player_idx):
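+    # Rescale the raw returns by the per-experiment divider, then min-max
+    # normalize them to [0, 1] using the game's min/max returns for this
+    # player.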
+ scaled_values = (
+ values_per_replicat_per_player / file_data.reward_adaptation_divider
+ )
+ normalized_values = scaled_values - file_data.min_r_by_players[player_idx]
+ normalized_values = normalized_values / (
+ file_data.max_r_by_players[player_idx]
+ - file_data.min_r_by_players[player_idx]
+ )
+ return normalized_values
+
+
+def _convert_str_of_list_to_list(str_of_list):
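+    # e.g. the string "[1.0, 2.5, -3.0]" becomes the list [1.0, 2.5, -3.0]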
+ return [
+ float(v)
+ for v in str_of_list.replace("[", "")
+ .replace("]", "")
+ .replace(" ", "")
+ .split(",")
+ ]
+
+
+def _plot_bars(perf_per_mode_per_files):
+ plt.figure(figsize=(10, 5))
+
+ plt.subplot(121)
+ _, x, groups = _plot_merged_players(perf_per_mode_per_files, mcp=False)
+ plt.xticks(x, groups, rotation=15)
+ plt.ylabel("Normalized scores")
+
+ plt.subplot(122)
+ legend, x, groups = _plot_merged_players(perf_per_mode_per_files, mcp=True)
+ plt.xticks(x, groups, rotation=15)
+ plt.ylabel("Normalized scores")
+
+ plt.legend(
+ legend,
+ frameon=True,
+ bbox_to_anchor=(1.0, -0.23),
+ )
+
+ # Save the figure and show
+ plt.tight_layout(rect=[0, -0.05, 1.0, 1.0])
+ plt.savefig("bar_plot.png")
+ plt.show()
+
+
+def _extract_value(all_perf, idx, player_idx, attrib):
+ return [
+ getattr(el[idx][player_idx], attrib)
+ if hasattr(el[idx][player_idx], attrib)
+ else 0.0
+ for el in all_perf
+ ]
+
+
+def _plot_merged_players(perf_per_mode_per_files, mcp: bool):
+ all_perf = [el.perf for el in perf_per_mode_per_files]
+ groups = [f"{el.env}+{el.base_algo}" for el in perf_per_mode_per_files]
+
+ width = 0.1
+ x_delta = 0.3
+
+ self_play_p0 = _extract_value(all_perf, 0, PLAYER_0, "raw")
+ self_play_p1 = _extract_value(all_perf, 0, PLAYER_1, "raw")
+ cross_play_p0 = _extract_value(all_perf, 1, PLAYER_0, "raw")
+ cross_play_p1 = _extract_value(all_perf, 1, PLAYER_1, "raw")
+ same_pref_p0 = _extract_value(all_perf, 2, PLAYER_0, "raw")
+ same_pref_p1 = _extract_value(all_perf, 2, PLAYER_1, "raw")
+ diff_pref_p0 = _extract_value(all_perf, 3, PLAYER_0, "raw")
+ diff_pref_p1 = _extract_value(all_perf, 3, PLAYER_1, "raw")
+
+ self_play = _avg_over_players(self_play_p0, self_play_p1)
+ cross_play = _avg_over_players(cross_play_p0, cross_play_p1)
+ same_pref_perf = _avg_over_players(same_pref_p0, same_pref_p1)
+ diff_pref_perf = _avg_over_players(diff_pref_p0, diff_pref_p1)
+
+ self_play_err = _get_std_err(self_play)
+ cross_play_err = _get_std_err(cross_play)
+ same_pref_perf_err = _get_std_err(same_pref_perf)
+ diff_pref_perf_err = _get_std_err(diff_pref_perf)
+
+ self_play = _get_mean(self_play)
+ cross_play = _get_mean(cross_play)
+ same_pref_perf = _get_mean(same_pref_perf)
+ diff_pref_perf = _get_mean(diff_pref_perf)
+
+ # We can't infer preferences (or pseudo-welfare function) in LOLA-Exact in
+ # IPD since only one equilibrium is generated. Thus removing these badly
+ # named data points
+ # same_pref_perf[LOLA_EXACT_WT_IPD_IDX] = 0.0
+ # diff_pref_perf[LOLA_EXACT_WT_IPD_IDX] = 0.0
+
+ if not mcp:
+ plt.text(1.33, 0.04, NA, fontdict={"fontsize": 10.0, "rotation": 90})
+ plt.text(1.43, 0.04, NA, fontdict={"fontsize": 10.0, "rotation": 90})
+ self_play = self_play[:N_NO_MCP]
+ cross_play = cross_play[:N_NO_MCP]
+ same_pref_perf = same_pref_perf[:N_NO_MCP]
+ diff_pref_perf = diff_pref_perf[:N_NO_MCP]
+ self_play_err = self_play_err[:N_NO_MCP]
+ cross_play_err = cross_play_err[:N_NO_MCP]
+ same_pref_perf_err = same_pref_perf_err[:N_NO_MCP]
+ diff_pref_perf_err = diff_pref_perf_err[:N_NO_MCP]
+ groups = groups[:N_NO_MCP]
+ subplot_name = "Not MCPs"
+ else:
+ self_play = self_play[N_NO_MCP:]
+ cross_play = cross_play[N_NO_MCP:]
+ same_pref_perf = same_pref_perf[N_NO_MCP:]
+ diff_pref_perf = diff_pref_perf[N_NO_MCP:]
+ self_play_err = self_play_err[N_NO_MCP:]
+ cross_play_err = cross_play_err[N_NO_MCP:]
+ same_pref_perf_err = same_pref_perf_err[N_NO_MCP:]
+ diff_pref_perf_err = diff_pref_perf_err[N_NO_MCP:]
+ groups = groups[N_NO_MCP:]
+ subplot_name = "MCPs"
+ plt.text(1.0, 1.1, subplot_name, fontdict={"fontsize": 14.0})
+
+ x = np.arange(len(self_play))
+
+ plt.bar(
+ x_delta + x - width * 1.5,
+ self_play,
+ width,
+ yerr=self_play_err,
+ )
+ plt.bar(
+ x_delta + x - width * 0.5,
+ cross_play,
+ width,
+ yerr=cross_play_err,
+ )
+ plt.bar(
+ x_delta + x + width * 0.5,
+ same_pref_perf,
+ width,
+ yerr=same_pref_perf_err,
+ )
+ plt.bar(
+ x_delta + x + width * 1.5,
+ diff_pref_perf,
+ width,
+ yerr=diff_pref_perf_err,
+ )
+ legend = [
+ "Self-play",
+ "Cross-play",
+ "Cross-play between identical preferences",
+ "Cross-play between different preferences",
+ ]
+
+ return legend, x, groups
+
+
+def _avg_over_players(values_player0, values_player1):
+ return [
+ (np.array(v_p0) + np.array(v_p1)) / 2
+ for v_p0, v_p1 in zip(values_player0, values_player1)
+ ]
+
+
+def _get_std_err(values):
+ return [
+ v.std() / np.sqrt(v.shape[0]) if len(v.shape) > 0 else 0.0
+ for v in values
+ ]
+
+
+def _get_mean(values):
+ return [v.mean() if len(v.shape) > 0 else 0.0 for v in values]
+
+
+if __name__ == "__main__":
+ debug_mode = False
+ main(debug_mode)
diff --git a/marltoolbox/utils/callbacks.py b/marltoolbox/utils/callbacks.py
index 1b5ad23..c66e84f 100644
--- a/marltoolbox/utils/callbacks.py
+++ b/marltoolbox/utils/callbacks.py
@@ -6,6 +6,7 @@
from ray.rllib.utils.typing import AgentID, PolicyID
from marltoolbox.utils.miscellaneous import logger
+from marltoolbox.utils import restore
if TYPE_CHECKING:
from ray.rllib.evaluation import RolloutWorker
diff --git a/marltoolbox/utils/config_helper.py b/marltoolbox/utils/config_helper.py
new file mode 100644
index 0000000..39fd01f
--- /dev/null
+++ b/marltoolbox/utils/config_helper.py
@@ -0,0 +1,68 @@
+from collections.abc import Callable
+
+from ray import tune
+from ray.rllib.utils import PiecewiseSchedule
+
+from marltoolbox.utils.miscellaneous import move_to_key
+
+
+def get_temp_scheduler() -> Callable:
+ """
+    Uses the hyperparameter 'temperature_steps_config' stored inside the
+    env_config dict, since no check is done on the keys of this
+    dictionary.
+    This hyperparameter is a list of tuples (List[Tuple]); each tuple
+    defines one step in the scheduler.
+
+ :return: an easily customizable temperature scheduler
+ """
+ return configurable_linear_scheduler("temperature_steps_config")
+
+
+def get_lr_scheduler() -> Callable:
+ """
+    Uses the hyperparameter 'lr_steps_config' stored inside the
+    env_config dict, since no check is done on the keys of this
+    dictionary.
+    This hyperparameter is a list of tuples (List[Tuple]); each tuple
+    defines one step in the scheduler.
+
+    :return: an easily customizable learning rate scheduler
+ """
+ return configurable_linear_scheduler(
+ "lr_steps_config", second_term_key="lr"
+ )
+
+
+def configurable_linear_scheduler(config_key, second_term_key: str = None):
+ """Returns a configurable linear scheduler which use the hyperparameters
+ stop.episodes_total and config.env_config.max_steps fro the RLLib
+ config."""
+
+ return tune.sample_from(
+ lambda spec: PiecewiseSchedule(
+ endpoints=[
+ (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ * step_config[0]
+ ),
+ step_config[1],
+ )
+ if second_term_key is None
+ else (
+ int(
+ spec.config["env_config"]["max_steps"]
+ * spec.stop["episodes_total"]
+ * step_config[0]
+ ),
+ step_config[1]
+ * move_to_key(spec.config, second_term_key)[2],
+ )
+ for step_config in spec.config["env_config"][config_key]
+ ],
+ outside_value=spec.config["env_config"][config_key][-1][1],
+ framework="torch",
+ )
+ )
diff --git a/marltoolbox/utils/cross_play/__init__.py b/marltoolbox/utils/cross_play/__init__.py
new file mode 100644
index 0000000..0db6d6f
--- /dev/null
+++ b/marltoolbox/utils/cross_play/__init__.py
@@ -0,0 +1,2 @@
+import marltoolbox.utils.cross_play.utils
+import marltoolbox.utils.cross_play.evaluator
diff --git a/marltoolbox/utils/self_and_cross_perf.py b/marltoolbox/utils/cross_play/evaluator.py
similarity index 69%
rename from marltoolbox/utils/self_and_cross_perf.py
rename to marltoolbox/utils/cross_play/evaluator.py
index b710a70..c0048a8 100644
--- a/marltoolbox/utils/self_and_cross_perf.py
+++ b/marltoolbox/utils/cross_play/evaluator.py
@@ -4,24 +4,19 @@
import os
import pickle
import random
-import warnings
from typing import Dict
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-
-plt.style.use("seaborn-whitegrid")
-
import ray
from ray import tune
from ray.rllib.agents.pg import PGTrainer
from ray.tune.analysis import ExperimentAnalysis
from ray.tune.integration.wandb import WandbLogger
from ray.tune.logger import DEFAULT_LOGGERS
+from ray.tune.logger import SafeFallbackEncoder
+from marltoolbox import utils
from marltoolbox.utils import restore, log, miscellaneous
-from marltoolbox.utils.plot import PlotHelper, PlotConfig
+from marltoolbox.utils.cross_play import ploter
logger = logging.getLogger(__name__)
@@ -84,8 +79,8 @@ def perform_evaluation_or_load_data(
stop_config,
policies_to_load_from_checkpoint,
tune_analysis_per_exp: list,
- TrainerClass=PGTrainer,
- TuneTrainerClass=None,
+ rllib_trainer_class=PGTrainer,
+ tune_trainer_class=None,
n_cross_play_per_checkpoint: int = 1,
n_self_play_per_checkpoint: int = 1,
to_load_path: str = None,
@@ -101,9 +96,9 @@ def perform_evaluation_or_load_data(
:param tune_analysis_per_exp: List of the tune_analysis you want to
extract the groups of checkpoints from. All the checkpoints in these
tune_analysis will be extracted.
- :param TrainerClass: (default is the PGTrainer class) Normal 1st argument (run_or_experiment) provided to
+ :param rllib_trainer_class: (default is the PGTrainer class) Normal 1st argument (run_or_experiment) provided to
tune.run(). You should use the one which provides the data flow you need. (Probably a simple PGTrainer will do).
- :param TuneTrainerClass: Will only be needed when you are going to evaluate policies created from a Tune
+ :param tune_trainer_class: Will only be needed when you are going to evaluate policies created from a Tune
trainer. You need to provide the class of this trainer.
:param n_cross_play_per_checkpoint: (int) How many cross-play experiment per checkpoint you want to run.
They are run randomly against the other checkpoints.
@@ -114,8 +109,8 @@ def perform_evaluation_or_load_data(
"""
if to_load_path is None:
self.define_the_experiment_to_run(
- TrainerClass=TrainerClass,
- TuneTrainerClass=TuneTrainerClass,
+ TrainerClass=rllib_trainer_class,
+ TuneTrainerClass=tune_trainer_class,
evaluation_config=evaluation_config,
stop_config=stop_config,
policies_to_load_from_checkpoint=policies_to_load_from_checkpoint,
@@ -139,7 +134,7 @@ def define_the_experiment_to_run(
stop_config: dict,
TuneTrainerClass=None,
TrainerClass=PGTrainer,
- policies_to_load_from_checkpoint: list = ["All"],
+        policies_to_load_from_checkpoint: list = ("All",),
):
"""
:param evaluation_config: Normal config argument provided to tune.run().
@@ -148,8 +143,9 @@ def define_the_experiment_to_run(
:param stop_config: Normal stop_config argument provided to tune.run().
:param TuneTrainerClass: Will only be needed when you are going to evaluate policies created from a Tune
trainer. You need to provide the class of this trainer.
- :param TrainerClass: (default is the PGTrainer class) Normal 1st argument (run_or_experiment) provided to
- tune.run(). You should use the one which provides the data flow you need. (Probably a simple PGTrainer will do).
+ :param TrainerClass: (default is the PGTrainer class) The usual 1st
+ argument provided to tune.run(). You should use the one which
+ provides the data flow you need. (Probably a simple PGTrainer will do).
:param policies_to_load_from_checkpoint:
"""
@@ -165,14 +161,16 @@ def define_the_experiment_to_run(
self.policies_ids_sorted = sorted(
list(self.evaluation_config["multiagent"]["policies"].keys())
)
- self.policies_to_load_from_checkpoint = sorted(
- [
- policy_id
- for policy_id in self.policies_ids_sorted
- if self._is_policy_to_load(
- policy_id, policies_to_load_from_checkpoint
- )
- ]
+ self.policies_to_load_from_checkpoint = tuple(
+ sorted(
+ [
+ policy_id
+ for policy_id in self.policies_ids_sorted
+ if self._is_policy_to_load(
+ policy_id, policies_to_load_from_checkpoint
+ )
+ ]
+ )
)
self.experiment_defined = True
@@ -218,8 +216,10 @@ def _extract_groups_of_checkpoints(
def _extract_one_group_of_checkpoints(
self, one_tune_result: ExperimentAnalysis, group_name
):
- checkpoints_in_one_group = miscellaneous.extract_checkpoints(
- one_tune_result
+ checkpoints_in_one_group = (
+ utils.restore.extract_checkpoints_from_tune_analysis(
+ one_tune_result
+ )
)
self.checkpoints.extend(
[
@@ -370,7 +370,7 @@ def _save_results_as_json(self, available_metrics_list):
save_path = self.save_path.split(".")[0:-1] + ["json"]
save_path = ".".join(save_path)
with open(save_path, "w") as outfile:
- json.dump(metrics, outfile)
+ json.dump(metrics, outfile, cls=SafeFallbackEncoder)
def load_results(self, to_load_path):
assert to_load_path.endswith(self.results_file_name), (
@@ -443,23 +443,6 @@ def _select_opponent_randomly(
)
return opponents
- def _split_results_per_mode_and_group_pair_id(
- self, all_metadata_wt_results
- ):
- analysis_per_mode = []
-
- metadata_per_modes = self._split_metadata_per_mode(
- all_metadata_wt_results
- )
- for mode, metadata_for_one_mode in metadata_per_modes.items():
- analysis_per_mode.extend(
- self._split_metadata_per_group_pair_id(
- metadata_for_one_mode, mode
- )
- )
-
- return analysis_per_mode
-
def _split_metadata_per_mode(self, all_results):
return {
mode: [report for report in all_results if report["mode"] == mode]
@@ -472,33 +455,63 @@ def _split_metadata_per_group_pair_id(self, metadata_for_one_mode, mode):
tune_analysis = [
metadata["results"] for metadata in metadata_for_one_mode
]
- group_pair_names = [
+ pairs_of_group_names = [
self._get_pair_of_group_names(metadata)
for metadata in metadata_for_one_mode
]
- group_pair_ids = [
- self._get_id_of_pair_of_group_names(one_pair_of_names)
- for one_pair_of_names in group_pair_names
+ ids_of_pairs_of_groups = [
+ self._get_id_of_pair_of_group_names(one_pair_of_group_names)
+ for one_pair_of_group_names in pairs_of_group_names
]
- group_pair_ids_in_this_mode = sorted(set(group_pair_ids))
+ group_pair_ids_in_this_mode = sorted(set(ids_of_pairs_of_groups))
for group_pair_id in list(group_pair_ids_in_this_mode):
(
filtered_analysis_list,
- one_pair_of_names,
+ one_pair_of_group_names,
) = self._find_and_group_results_for_one_group_pair_id(
- group_pair_id, tune_analysis, group_pair_ids, group_pair_names
+ group_pair_id,
+ tune_analysis,
+ ids_of_pairs_of_groups,
+ pairs_of_group_names,
)
analysis_per_group_pair_id.append(
(
mode,
filtered_analysis_list,
group_pair_id,
- one_pair_of_names,
+ one_pair_of_group_names,
)
)
return analysis_per_group_pair_id
+ def _get_pair_of_group_names(self, metadata):
+ # checkpoints_idx_used = [
+ # metadata[policy_id]["checkpoint_i"]
+ # for policy_id in self.policies_to_load_from_checkpoint
+ # ]
+ # pair_of_group_names = [
+ # self.checkpoints[checkpoint_i]["group_name"]
+ # for checkpoint_i in checkpoints_idx_used
+ # ]
+ checkpoints_idx_used = {
+ policy_id: metadata[policy_id]["checkpoint_i"]
+ for policy_id in self.policies_to_load_from_checkpoint
+ }
+ pair_of_group_names = {
+ policy_id: self.checkpoints[checkpoint_i]["group_name"]
+ for policy_id, checkpoint_i in checkpoints_idx_used.items()
+ }
+ return pair_of_group_names
+
+ def _get_id_of_pair_of_group_names(self, pair_of_group_names):
+ ordered_pair_of_group_names = [
+ pair_of_group_names[policy_id]
+ for policy_id in self.policies_to_load_from_checkpoint
+ ]
+ id_of_pair_of_group_names = "".join(ordered_pair_of_group_names)
+ return id_of_pair_of_group_names
+
def _find_and_group_results_for_one_group_pair_id(
self, group_pair_id, tune_analysis, group_pair_ids, group_pair_names
):
@@ -519,20 +532,6 @@ def _find_and_group_results_for_one_group_pair_id(
return filtered_tune_analysis, one_pair_of_names
- def _extract_all_metrics(self, analysis_per_mode):
- analysis_metrics_per_mode = []
- for mode_i, mode_data in enumerate(analysis_per_mode):
- mode, analysis_list, group_pair_id, group_pair_name = mode_data
-
- available_metrics_list = []
- for trial in analysis_list:
- available_metrics = trial.metric_analysis
- available_metrics_list.append(available_metrics)
- analysis_metrics_per_mode.append(
- (mode, available_metrics_list, group_pair_id, group_pair_name)
- )
- return analysis_metrics_per_mode
-
def _group_results_and_extract_metrics(self, all_metadata_wt_results):
# TODO improve the design to remove these unclear names
analysis_per_mode_per_group_pair_id = (
@@ -545,20 +544,36 @@ def _group_results_and_extract_metrics(self, all_metadata_wt_results):
)
return analysis_metrics_per_mode_per_group_pair_id
- def _get_id_of_pair_of_group_names(self, pair_of_group_names):
- id_of_pair_of_group_names = "".join(pair_of_group_names)
- return id_of_pair_of_group_names
+ def _split_results_per_mode_and_group_pair_id(
+ self, all_metadata_wt_results
+ ):
+ analysis_per_mode = []
- def _get_pair_of_group_names(self, metadata):
- checkpoints_idx_used = [
- metadata[policy_id]["checkpoint_i"]
- for policy_id in self.policies_to_load_from_checkpoint
- ]
- pair_of_group_names = [
- self.checkpoints[checkpoint_i]["group_name"]
- for checkpoint_i in checkpoints_idx_used
- ]
- return pair_of_group_names
+ metadata_per_modes = self._split_metadata_per_mode(
+ all_metadata_wt_results
+ )
+ for mode, metadata_for_one_mode in metadata_per_modes.items():
+ analysis_per_mode.extend(
+ self._split_metadata_per_group_pair_id(
+ metadata_for_one_mode, mode
+ )
+ )
+
+ return analysis_per_mode
+
+ def _extract_all_metrics(self, analysis_per_mode):
+ analysis_metrics_per_mode = []
+ for mode_i, mode_data in enumerate(analysis_per_mode):
+ mode, analysis_list, group_pair_id, group_pair_name = mode_data
+
+ available_metrics_list = []
+ for trial in analysis_list:
+ available_metrics = trial.metric_analysis
+ available_metrics_list.append(available_metrics)
+ analysis_metrics_per_mode.append(
+ (mode, available_metrics_list, group_pair_id, group_pair_name)
+ )
+ return analysis_metrics_per_mode
def plot_results(
self,
@@ -567,213 +582,181 @@ def plot_results(
x_axis_metric,
y_axis_metric,
):
- plotter = SelfAndCrossPlayPlotter()
- return plotter.plot_results(
- exp_parent_dir=self.exp_parent_dir,
- metrics_per_mode=analysis_metrics_per_mode,
- plot_config=plot_config,
- x_axis_metric=x_axis_metric,
- y_axis_metric=y_axis_metric,
+
+ vanilla_plot_path = self._plot_as_provided(
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
)
+ if plot_config.plot_max_n_points is not None:
+ plot_config.plot_max_n_points *= 2
-class SelfAndCrossPlayPlotter:
- def __init__(self):
- self.x_axis_metric = None
- self.y_axis_metric = None
- self.metric_mode = None
- self.stat_summary = None
- self.data_groups_per_mode = None
+ self._plot_merge_self_cross(
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
+ )
+ self._plot_merge_same_and_diff_pref(
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
+ )
+ return vanilla_plot_path
- def plot_results(
+ def _plot_as_provided(
self,
- exp_parent_dir: str,
- x_axis_metric: str,
- y_axis_metric: str,
- metrics_per_mode: list,
- plot_config: PlotConfig,
- metric_mode: str = "avg",
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
):
- self._reset(x_axis_metric, y_axis_metric, metric_mode)
- for metrics_for_one_evaluation_mode in metrics_per_mode:
- self._extract_performance_evaluation_points(
- metrics_for_one_evaluation_mode
- )
- self.stat_summary.save_summary(
- filename_prefix=RESULTS_SUMMARY_FILENAME_PREFIX,
- folder_dir=exp_parent_dir,
+ vanilla_plot_path = self._plot_one_time(
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
)
- return self._plot_and_save_fig(plot_config, exp_parent_dir)
+ return vanilla_plot_path
- def _reset(self, x_axis_metric, y_axis_metric, metric_mode):
- self.x_axis_metric = x_axis_metric
- self.y_axis_metric = y_axis_metric
- self.metric_mode = metric_mode
- self.stat_summary = StatisticSummary(
- self.x_axis_metric, self.y_axis_metric, self.metric_mode
+ def _plot_merge_self_cross(
+ self,
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
+ ):
+ self._plot_one_time_with_prefix_and_preprocess(
+ "_self_cross",
+ "_merge_into_self_and_cross",
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
)
- self.data_groups_per_mode = {}
- def _extract_performance_evaluation_points(
- self, metrics_for_one_evaluation_mode
+ def _plot_merge_same_and_diff_pref(
+ self,
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
):
- (
- mode,
- available_metrics_list,
- group_pair_id,
- group_pair_name,
- ) = metrics_for_one_evaluation_mode
-
- label = self._get_label(mode, group_pair_name)
- x, y = self._extract_x_y_points(available_metrics_list)
-
- self.stat_summary.aggregate_stats_on_data_points(x, y, label)
- self.data_groups_per_mode[label] = self._format_as_df(x, y)
- print("x, y", x, y)
-
- def _get_label(self, mode, group_pair_name):
- # For backward compatibility
- if mode == "Same-play" or mode == "same training run":
- mode = self.SELF_PLAY_MODE
- elif mode == "Cross-play" or mode == "cross training run":
- mode = self.CROSS_PLAY_MODE
-
- print("Evaluator mode:", mode)
- if self._suffix_needed(group_pair_name):
- print("Using group_pair_name:", group_pair_name)
- label = f"{mode}: " + " vs ".join(group_pair_name)
- else:
- label = mode
- label = label.replace("_", " ")
- print("label", label)
- return label
- def _suffix_needed(self, group_pair_name):
- return (
- group_pair_name is not None
- and all([name is not None for name in group_pair_name])
- and all(group_pair_name)
+ self._plot_one_time_with_prefix_and_preprocess(
+ "_same_and_diff_pref",
+ "_merge_into_cross_same_pref_diff_pref",
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
)
- def _extract_x_y_points(self, available_metrics_list):
- x, y = [], []
- assert len(available_metrics_list) > 0
- random.shuffle(available_metrics_list)
-
- for available_metrics in available_metrics_list:
- if self.x_axis_metric in available_metrics.keys():
- x_point = available_metrics[self.x_axis_metric][
- self.metric_mode
- ]
- else:
- x_point = 123456789
- warnings.warn(
- f"x_axis_metric {self.x_axis_metric}"
- " not in available_metrics "
- f"{available_metrics.keys()}"
- )
- if self.y_axis_metric in available_metrics.keys():
- y_point = available_metrics[self.y_axis_metric][
- self.metric_mode
- ]
+ def _plot_one_time_with_prefix_and_preprocess(
+ self,
+ prefix: str,
+ preprocess: str,
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
+ ):
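+        """
+        Preprocess the metrics with the method named by `preprocess`, plot
+        them once while temporarily appending `prefix` to the plot filename
+        prefix, then restore the initial filename prefix.
+        """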
+ initial_filename_prefix = plot_config.filename_prefix
+ plot_config.filename_prefix += prefix
+ metrics_for_same_pref_diff_pref = getattr(self, preprocess)(
+ analysis_metrics_per_mode
+ )
+ self._plot_one_time(
+ metrics_for_same_pref_diff_pref,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
+ )
+ plot_config.filename_prefix = initial_filename_prefix
+
+ def _merge_into_self_and_cross(self, analysis_metrics_per_mode):
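+        """
+        Merge all the results into one entry per play mode (one for
+        self-play and one for cross-play), dropping the per-group-pair
+        split.
+        """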
+ metrics_for_self_and_cross = []
+ for selected_play_mode in self.MODES:
+ metrics_for_self_and_cross.append(
+ (selected_play_mode, [], "", None)
+ )
+ for (
+ play_mode,
+ metrics,
+ pair_name,
+ pair_tuple,
+ ) in analysis_metrics_per_mode:
+ if play_mode == selected_play_mode:
+ metrics_for_self_and_cross[-1][1].extend(metrics)
+ return metrics_for_self_and_cross
+
+ def _merge_into_cross_same_pref_diff_pref(self, analysis_metrics_per_mode):
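+        """
+        Keep only the cross-play results and merge them into two entries:
+        pairs whose two groups share the same name (same preference) and
+        pairs whose group names differ (different preferences).
+        """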
+ analysis_metrics_per_mode_wtout_self_play = self._copy_wtout_self_play(
+ analysis_metrics_per_mode
+ )
+ one_pair_of_group_names = analysis_metrics_per_mode[0][3]
+ metrics_for_same_pref_diff_pref = [
+ (
+ self.CROSS_PLAY_MODE,
+ [],
+ "same_prefsame_pref",
+ {k: "same_pref" for k in one_pair_of_group_names.keys()},
+ ),
+ (
+ self.CROSS_PLAY_MODE,
+ [],
+ "diff_prefdiff_pref",
+ {k: "diff_pref" for k in one_pair_of_group_names.keys()},
+ ),
+ ]
+ for (
+ play_mode,
+ metrics,
+ pair_name,
+ pair_of_group_names,
+ ) in analysis_metrics_per_mode_wtout_self_play:
+ groups_names = list(pair_of_group_names.values())
+ assert len(groups_names) == 2
+ if groups_names[0] == groups_names[1]:
+ metrics_for_same_pref_diff_pref[0][1].extend(metrics)
else:
- y_point = 123456789
- warnings.warn(
- f"y_axis_metric {self.y_axis_metric}"
- " not in available_metrics "
- f"{available_metrics.keys()}"
+ metrics_for_same_pref_diff_pref[1][1].extend(metrics)
+ if len(metrics_for_same_pref_diff_pref[1][1]) == 0:
+ metrics_for_same_pref_diff_pref.pop(1)
+ if len(metrics_for_same_pref_diff_pref[0][1]) == 0:
+ metrics_for_same_pref_diff_pref.pop(0)
+ return metrics_for_same_pref_diff_pref
+
+ def _copy_wtout_self_play(self, analysis_metrics_per_mode):
+ analysis_metrics_per_mode_wtout_self_play = []
+ for (
+ play_mode,
+ metrics,
+ pair_name,
+ pair_tuple,
+ ) in analysis_metrics_per_mode:
+ if play_mode == self.CROSS_PLAY_MODE:
+ analysis_metrics_per_mode_wtout_self_play.append(
+ (self.CROSS_PLAY_MODE, metrics, pair_name, pair_tuple)
)
- x.append(x_point)
- y.append(y_point)
- return x, y
-
- def _format_as_df(self, x, y):
- group_df_dict = {
- "": [
- (one_x_point, one_y_point)
- for one_x_point, one_y_point in zip(x, y)
- ]
- }
- group_df = pd.DataFrame(group_df_dict)
- return group_df
-
- def _plot_and_save_fig(self, plot_config, exp_parent_dir):
- plot_helper = PlotHelper(plot_config)
- plot_helper.plot_cfg.save_dir_path = exp_parent_dir
- return plot_helper.plot_dots(self.data_groups_per_mode)
-
-
-class StatisticSummary:
- def __init__(self, x_axis_metric, y_axis_metric, metric_mode):
- self.x_means, self.x_se, self.x_labels, self.x_raw = [], [], [], []
- self.y_means, self.y_se, self.y_labels, self.y_raw = [], [], [], []
- self.matrix_label = []
- self.x_axis_metric, self.y_axis_metric = x_axis_metric, y_axis_metric
- self.metric_mode = metric_mode
-
- def aggregate_stats_on_data_points(self, x, y, label):
- # TODO refactor that to use a data structure
- # (like per metric and per plot?)
- self.x_means.append(sum(x) / len(x))
- self.x_se.append(np.array(x).std() / np.sqrt(len(x)))
- self.x_labels.append(
- f"Metric:{self.x_axis_metric}, " f"Metric mode:{self.metric_mode}"
- )
-
- self.y_means.append(sum(y) / len(y))
- self.y_se.append(np.array(y).std() / np.sqrt(len(y)))
- self.y_labels.append(
- f"Metric:{self.y_axis_metric}, " f"Metric mode:{self.metric_mode}"
- )
-
- self.matrix_label.append(label)
- self.x_raw.append(x)
- self.y_raw.append(y)
-
- def save_summary(self, filename_prefix, folder_dir):
- file_name = (
- f"{filename_prefix}_{self.y_axis_metric}_"
- f"vs_{self.x_axis_metric}_matrix.json"
- )
- file_name = file_name.replace("/", "_")
- file_path = os.path.join(folder_dir, file_name)
- formated_data = {}
- for step_i in range(len(self.x_means)):
- (
- x_mean,
- x_std_err,
- x_lbl,
- y_mean,
- y_std_err,
- y_lbl,
- lbl,
- x,
- y,
- ) = self._get_values_from_a_data_point(step_i)
- formated_data[lbl] = {
- x_lbl: {
- "mean": x_mean,
- "std_err": x_std_err,
- "raw_data": str(x),
- },
- y_lbl: {
- "mean": y_mean,
- "std_err": y_std_err,
- "raw_data": str(y),
- },
- }
- with open(file_path, "w") as f:
- json.dump(formated_data, f, indent=4, sort_keys=True)
+ return analysis_metrics_per_mode_wtout_self_play
- def _get_values_from_a_data_point(self, step_i):
- return (
- self.x_means[step_i],
- self.x_se[step_i],
- self.x_labels[step_i],
- self.y_means[step_i],
- self.y_se[step_i],
- self.y_labels[step_i],
- self.matrix_label[step_i],
- self.x_raw[step_i],
- self.y_raw[step_i],
+ def _plot_one_time(
+ self,
+ analysis_metrics_per_mode,
+ plot_config,
+ x_axis_metric,
+ y_axis_metric,
+ ):
+ plotter = ploter.SelfAndCrossPlayPlotter()
+ plot_path = plotter.plot_results(
+ exp_parent_dir=self.exp_parent_dir,
+ metrics_per_mode=analysis_metrics_per_mode,
+ plot_config=plot_config,
+ x_axis_metric=x_axis_metric,
+ y_axis_metric=y_axis_metric,
)
+ return plot_path
diff --git a/marltoolbox/utils/cross_play/ploter.py b/marltoolbox/utils/cross_play/ploter.py
new file mode 100644
index 0000000..a6a218b
--- /dev/null
+++ b/marltoolbox/utils/cross_play/ploter.py
@@ -0,0 +1,168 @@
+import logging
+import random
+import copy
+import pandas as pd
+from ray.util.debug import log_once
+
+from marltoolbox.utils.cross_play import evaluator
+from marltoolbox.utils.cross_play.stats_summary import StatisticSummary
+from marltoolbox.utils.plot import PlotHelper, PlotConfig
+
+logger = logging.getLogger(__name__)
+
+
+class SelfAndCrossPlayPlotter:
+ def __init__(self):
+ self.x_axis_metric = None
+ self.y_axis_metric = None
+ self.metric_mode = None
+ self.stat_summary = None
+ self.data_groups_per_mode = None
+
+ def plot_results(
+ self,
+ exp_parent_dir: str,
+ x_axis_metric: str,
+ y_axis_metric: str,
+ metrics_per_mode: list,
+ plot_config: PlotConfig,
+ metric_mode: str = "avg",
+ ):
+ self._reset(x_axis_metric, y_axis_metric, metric_mode)
+ for metrics_for_one_evaluation_mode in metrics_per_mode:
+ self._extract_performance_evaluation_points(
+ metrics_for_one_evaluation_mode
+ )
+ stat_summary_filename_prefix = (
+ plot_config.filename_prefix
+ + evaluator.RESULTS_SUMMARY_FILENAME_PREFIX
+ )
+ self.stat_summary.save_summary(
+ filename_prefix=stat_summary_filename_prefix,
+ folder_dir=exp_parent_dir,
+ )
+ return self._plot_and_save_fig(plot_config, exp_parent_dir)
+
+ def _reset(self, x_axis_metric, y_axis_metric, metric_mode):
+ self.x_axis_metric = x_axis_metric
+ self.y_axis_metric = y_axis_metric
+ self.metric_mode = metric_mode
+ self.stat_summary = StatisticSummary(
+ self.x_axis_metric, self.y_axis_metric, self.metric_mode
+ )
+ self.data_groups_per_mode = {}
+
+ def _extract_performance_evaluation_points(
+ self, metrics_for_one_evaluation_mode
+ ):
+ (
+ mode,
+ available_metrics_list,
+ group_pair_id,
+ group_pair_name,
+ ) = metrics_for_one_evaluation_mode
+
+ label = self._get_label(mode, group_pair_name)
+ x, y = self._extract_x_y_points(available_metrics_list)
+
+ self.stat_summary.aggregate_stats_on_data_points(x, y, label)
+ self.data_groups_per_mode[label] = self._format_as_df(x, y)
+ print("x, y", x, y)
+
+ def _get_label(self, mode, group_pair_name):
+ print("Evaluator mode:", mode)
+ if self._suffix_needed(group_pair_name):
+ ordered_group_pair_name = self._order_group_names(group_pair_name)
+ print(
+ "Using ordered_group_pair_name:",
+ ordered_group_pair_name,
+ "from group_pair_name:",
+ group_pair_name,
+ )
+ label = f"{mode}: " + " vs ".join(ordered_group_pair_name)
+ else:
+ label = mode
+ label = label.replace("_", " ")
+ print("label", label)
+ return label
+
+ def _suffix_needed(self, group_pair_name):
+ if group_pair_name is None:
+ return False
+ return all(
+ [name is not None for name in group_pair_name.values()]
+ ) and all(group_pair_name.values())
+
+ def _order_group_names(self, group_pair_name_original):
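+        """
+        Order the two group names to follow the order of the metrics
+        (x axis first, then y axis) by matching each policy id against the
+        metric names.
+        """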
+ group_pair_name = copy.deepcopy(group_pair_name_original)
+ ordered_group_pair_name = []
+ for metric in (self.x_axis_metric, self.y_axis_metric):
+ for policy_id, one_group_name in group_pair_name.items():
+ print(
+ "_order_group_names policy_id in metric", policy_id, metric
+ )
+ if policy_id in metric:
+ ordered_group_pair_name.append(one_group_name)
+ group_pair_name.pop(policy_id)
+ break
+ assert len(group_pair_name.keys()) == 0, (
+ "group_pair_name_original.keys() "
+ f"{group_pair_name_original.keys()} not in the metrics provided: "
+ "(self.x_axis_metric, self.y_axis_metric) "
+ f"{(self.x_axis_metric, self.y_axis_metric)}"
+ )
+ return ordered_group_pair_name
+
+ def _extract_x_y_points(self, available_metrics_list):
+ x, y = [], []
+ assert len(available_metrics_list) > 0
+ random.shuffle(available_metrics_list)
+
+ for available_metrics in available_metrics_list:
+ if self.x_axis_metric in available_metrics.keys():
+ x_point = available_metrics[self.x_axis_metric][
+ self.metric_mode
+ ]
+ else:
+ x_point = 123456789
+ msg = (
+ f"x_axis_metric {self.x_axis_metric}"
+ " not in available_metrics "
+ f"{available_metrics.keys()}"
+ )
+ if log_once(msg):
+ logger.warning(msg)
+
+ if self.y_axis_metric in available_metrics.keys():
+ y_point = available_metrics[self.y_axis_metric][
+ self.metric_mode
+ ]
+ else:
+ y_point = 123456789
+ msg = (
+ f"y_axis_metric {self.y_axis_metric}"
+ " not in available_metrics "
+ f"{available_metrics.keys()}"
+ )
+ if log_once(msg):
+ logger.warning(msg)
+ x.append(x_point)
+ y.append(y_point)
+ return x, y
+
+ def _format_as_df(self, x, y):
+ group_df_dict = {
+ "": [
+ (one_x_point, one_y_point)
+ for one_x_point, one_y_point in zip(x, y)
+ ]
+ }
+ group_df = pd.DataFrame(group_df_dict)
+ return group_df
+
+ def _plot_and_save_fig(self, plot_config, exp_parent_dir):
+ plot_helper = PlotHelper(plot_config)
+ plot_helper.plot_cfg.save_dir_path = exp_parent_dir
+ return plot_helper.plot_dots(self.data_groups_per_mode)
diff --git a/marltoolbox/utils/cross_play/stats_summary.py b/marltoolbox/utils/cross_play/stats_summary.py
new file mode 100644
index 0000000..a24dafa
--- /dev/null
+++ b/marltoolbox/utils/cross_play/stats_summary.py
@@ -0,0 +1,80 @@
+import json
+import os
+
+import numpy as np
+
+
+class StatisticSummary:
+ def __init__(self, x_axis_metric, y_axis_metric, metric_mode):
+ self.x_means, self.x_se, self.x_labels, self.x_raw = [], [], [], []
+ self.y_means, self.y_se, self.y_labels, self.y_raw = [], [], [], []
+ self.matrix_label = []
+ self.x_axis_metric, self.y_axis_metric = x_axis_metric, y_axis_metric
+ self.metric_mode = metric_mode
+
+ def aggregate_stats_on_data_points(self, x, y, label):
+ # TODO refactor that to use a data structure
+ # (like per metric and per plot?)
+ self.x_means.append(sum(x) / len(x))
+ self.x_se.append(np.array(x).std() / np.sqrt(len(x)))
+ self.x_labels.append(
+ f"Metric:{self.x_axis_metric}, " f"Metric mode:{self.metric_mode}"
+ )
+
+ self.y_means.append(sum(y) / len(y))
+ self.y_se.append(np.array(y).std() / np.sqrt(len(y)))
+ self.y_labels.append(
+ f"Metric:{self.y_axis_metric}, " f"Metric mode:{self.metric_mode}"
+ )
+
+ self.matrix_label.append(label)
+ self.x_raw.append(x)
+ self.y_raw.append(y)
+
+ def save_summary(self, filename_prefix, folder_dir):
+ file_name = (
+ f"{filename_prefix}_{self.y_axis_metric}_"
+ f"vs_{self.x_axis_metric}_matrix.json"
+ )
+ file_name = file_name.replace("/", "_")
+ file_path = os.path.join(folder_dir, file_name)
+ formated_data = {}
+ for step_i in range(len(self.x_means)):
+ (
+ x_mean,
+ x_std_err,
+ x_lbl,
+ y_mean,
+ y_std_err,
+ y_lbl,
+ lbl,
+ x,
+ y,
+ ) = self._get_values_from_a_data_point(step_i)
+ formated_data[lbl] = {
+ x_lbl: {
+ "mean": x_mean,
+ "std_err": x_std_err,
+ "raw_data": str(x),
+ },
+ y_lbl: {
+ "mean": y_mean,
+ "std_err": y_std_err,
+ "raw_data": str(y),
+ },
+ }
+ with open(file_path, "w") as f:
+ json.dump(formated_data, f, indent=4, sort_keys=True)
+
+ def _get_values_from_a_data_point(self, step_i):
+ return (
+ self.x_means[step_i],
+ self.x_se[step_i],
+ self.x_labels[step_i],
+ self.y_means[step_i],
+ self.y_se[step_i],
+ self.y_labels[step_i],
+ self.matrix_label[step_i],
+ self.x_raw[step_i],
+ self.y_raw[step_i],
+ )
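+
+
+# Minimal usage sketch (the metric names, values and paths below are
+# illustrative only):
+#   summary = StatisticSummary(
+#       x_axis_metric="policy_reward_mean/player_row",
+#       y_axis_metric="policy_reward_mean/player_col",
+#       metric_mode="avg",
+#   )
+#   summary.aggregate_stats_on_data_points(
+#       x=[1.0, 2.0], y=[0.5, 1.5], label="cross-play"
+#   )
+#   summary.save_summary(
+#       filename_prefix="self_and_cross_play",
+#       folder_dir="~/ray_results/my_exp",
+#   )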
diff --git a/marltoolbox/utils/cross_play/utils.py b/marltoolbox/utils/cross_play/utils.py
new file mode 100644
index 0000000..0bf6d55
--- /dev/null
+++ b/marltoolbox/utils/cross_play/utils.py
@@ -0,0 +1,112 @@
+import copy
+import random
+from typing import Dict, List
+
+from ray import tune
+
+
+def mix_policies_in_given_rllib_configs(
+ all_rllib_configs: List[Dict], n_mix_per_config: int
+) -> dict:
+ """
+    Mix the policies of a list of RLLib config dictionaries. Limited to
+    RLLib configs with 2 policies.
+    (Not used by the SelfAndCrossPlayEvaluator)
+
+    :param all_rllib_configs: list of the RLLib configs whose policies will
+    be mixed
+    :param n_mix_per_config: number of mixes to create for each rllib config
+    provided
+    :return: a single rllib config with a grid search over all the mixed
+    pairs of policies
+ """
+ assert (
+ n_mix_per_config <= len(all_rllib_configs) - 1
+ ), f" {n_mix_per_config} <= {len(all_rllib_configs) - 1}"
+ policy_ids = all_rllib_configs[0]["multiagent"]["policies"].keys()
+    assert len(policy_ids) == 2, (
+        "only supporting config dict with 2 RLLib policies"
+    )
+ _assert_all_config_use_the_same_policies(all_rllib_configs, policy_ids)
+
+ policy_config_variants = _gather_policy_variant_per_policy_id(
+ all_rllib_configs, policy_ids
+ )
+
+ master_config = _create_one_master_config(
+ all_rllib_configs, policy_config_variants, policy_ids, n_mix_per_config
+ )
+ return master_config
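+
+
+# Minimal usage sketch (the config names and the trainer below are
+# assumptions made only for this example):
+#   mixed_config = mix_policies_in_given_rllib_configs(
+#       all_rllib_configs=[rllib_config_run_a, rllib_config_run_b],
+#       n_mix_per_config=1,
+#   )
+#   tune.run(DQNTrainer, config=mixed_config)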
+
+
+def _create_one_master_config(
+ all_rllib_configs, policy_config_variants, policy_ids, n_mix_per_config
+):
+ all_policy_mix = []
+ player_1, player_2 = policy_ids
+ for config_idx, p1_policy_config in enumerate(
+ policy_config_variants[player_1]
+ ):
+ policies_mixes = _produce_n_mix_with_player_2_policies(
+ policy_config_variants,
+ player_2,
+ config_idx,
+ n_mix_per_config,
+ player_1,
+ p1_policy_config,
+ )
+ all_policy_mix.extend(policies_mixes)
+
+ master_config = copy.deepcopy(all_rllib_configs[0])
+ print("len(all_policy_mix)", len(all_policy_mix))
+ master_config["multiagent"]["policies"] = tune.grid_search(all_policy_mix)
+ return master_config
+
+
+def _produce_n_mix_with_player_2_policies(
+ policy_config_variants,
+ player_2,
+ config_idx,
+ n_mix_per_config,
+ player_1,
+ p1_policy_config,
+):
+ p2_policy_configs_sampled = _get_p2_policies_samples_excluding_self(
+ policy_config_variants, player_2, config_idx, n_mix_per_config
+ )
+ policies_mixes = []
+ for p2_policy_config in p2_policy_configs_sampled:
+ policy_mix = {
+ player_1: p1_policy_config,
+ player_2: p2_policy_config,
+ }
+ policies_mixes.append(policy_mix)
+ return policies_mixes
+
+
+def _get_p2_policies_samples_excluding_self(
+ policy_config_variants, player_2, config_idx, n_mix_per_config
+):
+ p2_policy_config_variants = copy.deepcopy(policy_config_variants[player_2])
+ p2_policy_config_variants.pop(config_idx)
+ p2_policy_configs_sampled = random.sample(
+ p2_policy_config_variants, n_mix_per_config
+ )
+ return p2_policy_configs_sampled
+
+
+def _assert_all_config_use_the_same_policies(all_rllib_configs, policy_ids):
+ for rllib_config in all_rllib_configs:
+ assert rllib_config["multiagent"]["policies"].keys() == policy_ids
+
+
+def _gather_policy_variant_per_policy_id(all_rllib_configs, policy_ids):
+ policy_config_variants = {}
+ for policy_id in policy_ids:
+ policy_config_variants[policy_id] = []
+ for rllib_config in all_rllib_configs:
+ policy_config_variants[policy_id].append(
+ copy.deepcopy(
+ rllib_config["multiagent"]["policies"][policy_id]
+ )
+ )
+ return policy_config_variants
diff --git a/marltoolbox/utils/log/__init__.py b/marltoolbox/utils/log/__init__.py
new file mode 100644
index 0000000..111e73e
--- /dev/null
+++ b/marltoolbox/utils/log/__init__.py
@@ -0,0 +1,19 @@
+from marltoolbox.utils.log.log import *
+from marltoolbox.utils.log.log import log_learning_rate
+from marltoolbox.utils.log.full_epi_logger import FullEpisodeLogger
+from marltoolbox.utils.log.model_summarizer import ModelSummarizer
+
+__all__ = [
+ "log_learning_rate",
+ "FullEpisodeLogger",
+ "ModelSummarizer",
+ "log_learning_rate",
+ "pprint_saved_metrics",
+ "save_metrics",
+ "extract_all_metrics_from_results",
+ "log_in_current_day_dir",
+ "compute_entropy_from_raw_q_values",
+ "augment_stats_fn_wt_additionnal_logs",
+ "get_log_from_policy",
+ "add_entropy_to_log",
+]
diff --git a/marltoolbox/utils/full_epi_logger.py b/marltoolbox/utils/log/full_epi_logger.py
similarity index 56%
rename from marltoolbox/utils/full_epi_logger.py
rename to marltoolbox/utils/log/full_epi_logger.py
index 98017c2..32f957b 100644
--- a/marltoolbox/utils/full_epi_logger.py
+++ b/marltoolbox/utils/log/full_epi_logger.py
@@ -4,18 +4,32 @@
import numpy as np
from ray.rllib.evaluation import MultiAgentEpisode
-from ray.tune.logger import _SafeFallbackEncoder
+from ray.tune.logger import SafeFallbackEncoder
logger = logging.getLogger(__name__)
class FullEpisodeLogger:
-
- def __init__(self, logdir, log_interval, log_ful_epi_one_hot_obs):
+ """
+ Helper to log the entire history of one episode as txt
+ """
+
+ def __init__(
+ self, logdir: str, log_interval: int, convert_one_hot_obs_to_idx: bool
+ ):
+ """
+
+        :param logdir: dir where to save the log file with the full episodes
+        :param log_interval: interval (in number of episodes) between the
+        logging of two episodes
+        :param convert_one_hot_obs_to_idx: bool flag to choose to convert
+        the observations to their indices
+        (intended to be used when dealing with one-hot observations)
+ """
self.log_interval = log_interval
- self.log_ful_epi_one_hot_obs = log_ful_epi_one_hot_obs
+ self.log_ful_epi_one_hot_obs = convert_one_hot_obs_to_idx
- file_path = os.path.join(logdir, f"full_episodes_logs.json")
+ file_path = os.path.join(logdir, "full_episodes_logs.json")
self.file_path = os.path.expanduser(file_path)
logger.info(f"FullEpisodeLogger: using as file_path: {self.file_path}")
@@ -23,7 +37,6 @@ def __init__(self, logdir, log_interval, log_ful_epi_one_hot_obs):
self.internal_episode_counter = -1
self.step_counter = 0
self.episode_finised = True
- self._first_fake_step_done = False
self.json_logger = JsonSimpleLogger(self.file_path)
@@ -45,7 +58,8 @@ def _init_logging_new_full_episode(self):
self._log_full_epi_tmp_data = {}
def on_episode_step(
- self, episode: MultiAgentEpisode = None, step_data: dict = None):
+ self, episode: MultiAgentEpisode = None, step_data: dict = None
+ ):
if not self._log_current_full_episode:
return None
@@ -56,22 +70,23 @@ def on_episode_step(
step_data = {}
for agent_id, policy in episode._policies.items():
- if self._first_fake_step_done:
- if agent_id in self._log_full_epi_tmp_data.keys():
- obs_before_act = self._log_full_epi_tmp_data[agent_id]
- else:
- obs_before_act = None
- action = episode.last_action_for(agent_id).tolist()
- epi = episode.episode_id
- rewards = episode._agent_reward_history[agent_id]
- reward = rewards[-1] if len(rewards) > 0 else None
- info = episode.last_info_for(agent_id)
- if hasattr(policy, "to_log"):
- info.update(policy.to_log)
- else:
- logger.info(f"policy {policy} doesn't have attrib "
- "to_log. hasattr(policy, 'to_log'): "
- f"{hasattr(policy, 'to_log')}")
+ if agent_id in self._log_full_epi_tmp_data.keys():
+ obs_before_act = self._log_full_epi_tmp_data[agent_id]
+ else:
+ obs_before_act = None
+ action = episode.last_action_for(agent_id).tolist()
+ epi = episode.episode_id
+ rewards = episode._agent_reward_history[agent_id]
+ reward = rewards[-1] if len(rewards) > 0 else None
+ info = episode.last_info_for(agent_id)
+ if hasattr(policy, "to_log"):
+ info.update(policy.to_log)
+ else:
+ logger.info(
+ f"policy {policy} doesn't have attrib "
+ "to_log. hasattr(policy, 'to_log'): "
+ f"{hasattr(policy, 'to_log')}"
+ )
# Episode provide the last action with the given last
# observation produced by this action. But we need the
# observation that cause the agent to play this action
@@ -79,40 +94,37 @@ def on_episode_step(
obs_after_act = episode.last_observation_for(agent_id)
self._log_full_epi_tmp_data[agent_id] = obs_after_act
- if self._first_fake_step_done:
- if self.log_ful_epi_one_hot_obs:
- obs_before_act = np.argwhere(obs_before_act)
- obs_after_act = np.argwhere(obs_after_act)
-
- step_data[agent_id] = {
- "obs_before_act": obs_before_act,
- "obs_after_act": obs_after_act,
- "action": action,
- "reward": reward,
- "info": info,
- "epi": epi}
-
- if self._first_fake_step_done:
- self.json_logger.write_json(step_data)
- self.json_logger.write("\n")
- self.step_counter += 1
- else:
- logger.info("FullEpisodeLogger: don't log first fake step")
- self._first_fake_step_done = True
+ if self.log_ful_epi_one_hot_obs:
+ obs_before_act = np.argwhere(obs_before_act)
+ obs_after_act = np.argwhere(obs_after_act)
+
+ step_data[agent_id] = {
+ "obs_before_act": obs_before_act,
+ "obs_after_act": obs_after_act,
+ "action": action,
+ "reward": reward,
+ "info": info,
+ "epi": epi,
+ }
+
+ self.json_logger.write_json(step_data)
+ self.json_logger.write("\n")
+ self.step_counter += 1
def on_episode_end(self, base_env=None):
if self._log_current_full_episode:
if base_env is not None:
env = base_env.get_unwrapped()[0]
if hasattr(env, "max_steps"):
- assert self.step_counter == env.max_steps, \
- "The number of steps written to full episode " \
- "log file must be equal to the number of step in an " \
- f"episode self.step_counter {self.step_counter} " \
- f"must equal env.max_steps {env.max_steps}. " \
- "Otherwise there are some issue with the " \
- "state of the callback object, maybe being used by " \
+ assert self.step_counter == env.max_steps, (
+ "The number of steps written to full episode "
+ "log file must be equal to the number of step in an "
+ f"episode self.step_counter {self.step_counter} "
+ f"must equal env.max_steps {env.max_steps}. "
+ "Otherwise there are some issue with the "
+ "state of the callback object, maybe being used by "
"several experiments at the same time."
+ )
self.json_logger.write_json(
{"status": f"end of episode {self.internal_episode_counter}"}
)
@@ -126,12 +138,19 @@ def on_episode_end(self, base_env=None):
class JsonSimpleLogger:
+ """
+ Simple logger in json format
+ """
def __init__(self, file_path):
+ """
+
+ :param file_path: file path to the file to save to
+ """
self.local_file = file_path
def write_json(self, json_data):
- json.dump(json_data, self, cls=_SafeFallbackEncoder)
+ json.dump(json_data, self, cls=SafeFallbackEncoder)
def write(self, b):
self.local_out.write(b)
diff --git a/marltoolbox/utils/log.py b/marltoolbox/utils/log/log.py
similarity index 68%
rename from marltoolbox/utils/log.py
rename to marltoolbox/utils/log/log.py
index d7f38eb..6625549 100644
--- a/marltoolbox/utils/log.py
+++ b/marltoolbox/utils/log/log.py
@@ -1,14 +1,12 @@
import copy
import datetime
import logging
-import math
import numbers
import os
import pickle
import pprint
import re
from collections import Iterable
-from typing import Dict, Callable, TYPE_CHECKING
import gym
import torch
@@ -17,12 +15,13 @@
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.typing import PolicyID, TensorType
-from scipy.special import softmax
-from torch.nn import Module
+from ray.util.debug import log_once
+from torch.distributions import Categorical
+from typing import Dict, Callable, TYPE_CHECKING, Optional, List
-from marltoolbox.utils.full_epi_logger import FullEpisodeLogger
+from marltoolbox.utils.log.full_epi_logger import FullEpisodeLogger
+from marltoolbox.utils.log.model_summarizer import ModelSummarizer
if TYPE_CHECKING:
from ray.rllib.evaluation import RolloutWorker
@@ -33,6 +32,7 @@
def get_logging_callbacks_class(
log_env_step: bool = True,
log_from_policy: bool = True,
+ log_from_policy_in_evaluation: bool = False,
log_full_epi: bool = False,
log_full_epi_interval: int = 100,
log_ful_epi_one_hot_obs: bool = True,
@@ -49,7 +49,7 @@ def on_episode_start(
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode,
- env_index: int,
+ env_index: Optional[int] = None,
**kwargs,
):
if log_full_epi:
@@ -67,90 +67,13 @@ def _init_full_episode_logging(self, worker):
self._full_episode_logger = FullEpisodeLogger(
logdir=worker.io_context.log_dir,
log_interval=log_full_epi_interval,
- log_ful_epi_one_hot_obs=log_ful_epi_one_hot_obs,
+ convert_one_hot_obs_to_idx=log_ful_epi_one_hot_obs,
)
logger.info("_full_episode_logger init done")
def _log_model_sumamry(self, worker):
- if not hasattr(self, "_log_model_sumamry_done"):
- self._log_model_sumamry_done = True
- self._for_every_policy_print_model_stats(worker)
-
- def _for_every_policy_print_model_stats(self, worker):
- for policy_id, policy in worker.policy_map.items():
- msg = f"===== Models summaries policy_id {policy_id} ====="
- print(msg)
- logger.info(msg)
- self._print_model_summary(policy)
- self._count_parameters_in_every_modules(policy)
-
- @staticmethod
- def _print_model_summary(policy):
- if isinstance(policy, TorchPolicy):
- for k, v in policy.__dict__.items():
- if isinstance(v, Module):
- msg = f"{k}, {v}"
- print(msg)
- logger.info(msg)
-
- def _count_parameters_in_every_modules(self, policy):
- if isinstance(policy, TorchPolicy):
- for k, v in policy.__dict__.items():
- if isinstance(v, Module):
- self._count_and_log_for_one_module(policy, k, v)
-
- def _count_and_log_for_one_module(self, policy, module_name, module):
- n_param = self._count_parameters(module, module_name)
- n_param_shared_counted_once = self._count_parameters(
- module, module_name, count_shared_once=True
- )
- n_param_trainable = self._count_parameters(
- module, module_name, only_trainable=True
- )
- self._log_values_in_to_log(
- policy,
- {
- f"module_{module_name}_n_param": n_param,
- f"module_{module_name}_n_param_shared_counted_once": n_param_shared_counted_once,
- f"module_{module_name}_n_param_trainable": n_param_trainable,
- },
- )
-
- @staticmethod
- def _log_values_in_to_log(policy, dictionary):
- if hasattr(policy, "to_log"):
- policy.to_log.update(dictionary)
-
- @staticmethod
- def _count_parameters(
- m: torch.nn.Module,
- module_name: str,
- count_shared_once: bool = False,
- only_trainable: bool = False,
- ):
- """
- returns the total number of parameters used by `m` (only counting
- shared parameters once); if `only_trainable` is True, then only
- includes parameters with `requires_grad = True`
- """
- parameters = m.parameters()
- if only_trainable:
- parameters = list(p for p in parameters if p.requires_grad)
- if count_shared_once:
- parameters = dict(
- (p.data_ptr(), p) for p in parameters
- ).values()
- number_of_parameters = sum(p.numel() for p in parameters)
-
- msg = (
- f"{module_name}: "
- f"number_of_parameters: {number_of_parameters} "
- f"(only_trainable: {only_trainable}, "
- f"count_shared_once: {count_shared_once})"
- )
- print(msg)
- logger.info(msg)
- return number_of_parameters
+ if log_once("model_summaries"):
+ ModelSummarizer.for_every_policy_print_model_stats(worker)
def on_episode_step(
self,
@@ -158,9 +81,11 @@ def on_episode_step(
worker: "RolloutWorker",
base_env: BaseEnv,
episode: MultiAgentEpisode,
- env_index: int,
+ env_index: Optional[int] = None,
**kwargs,
):
+ if log_from_policy_in_evaluation:
+ self._update_epi_info_wt_to_log(worker, episode)
if log_env_step:
self._add_env_info_to_custom_metrics(worker, episode)
if log_full_epi:
@@ -173,7 +98,7 @@ def on_episode_end(
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode,
- env_index: int,
+ env_index: Optional[int] = None,
**kwargs,
):
if log_full_epi:
@@ -192,6 +117,11 @@ def on_train_result(self, *, trainer, result: dict, **kwargs):
self._update_train_result_wt_to_log(
trainer, result, function_to_exec=get_log_from_policy
)
+ self._update_train_result_wt_to_log(
+ trainer,
+ result,
+ function_to_exec=get_explore_temperature_from_policy,
+ )
if log_weights:
if not hasattr(self, "on_train_result_counter"):
self.on_train_result_counter = 0
@@ -199,24 +129,10 @@ def on_train_result(self, *, trainer, result: dict, **kwargs):
self._update_train_result_wt_to_log(
trainer,
result,
- function_to_exec=self._get_weights_from_policy,
+ function_to_exec=get_weights_from_policy,
)
self.on_train_result_counter += 1
- @staticmethod
- def _get_weights_from_policy(
- policy: Policy, policy_id: PolicyID
- ) -> dict:
- """Gets the to_log var from a policy and rename its keys, adding the policy_id as a prefix."""
- to_log = {}
- weights = policy.get_weights()
-
- for k, v in weights.items():
- if isinstance(v, Iterable):
- to_log[f"{policy_id}/{k}"] = v
-
- return to_log
-
@staticmethod
def _add_env_info_to_custom_metrics(worker, episode):
@@ -239,10 +155,28 @@ def _update_train_result_wt_to_log(
def exec_in_each_policy(worker):
return worker.foreach_policy(function_to_exec)
- # to_log_list = trainer.workers.foreach_policy(function_to_exec)
to_log_list_list = trainer.workers.foreach_worker(
exec_in_each_policy
)
+ self._unroll_into_logs(result, to_log_list_list)
+
+ def _update_epi_info_wt_to_log(
+ self, worker, episode: MultiAgentEpisode
+ ):
+ """
+        Add the logs from every policy (from policy.to_log: dict)
+        to the episode info (to be plotted in Tensorboard).
+        To be called from the on_episode_step callback.
+ """
+
+ for policy_id, policy in worker.policy_map.items():
+ to_log = get_log_from_policy(policy, policy_id)
+ episode._agent_to_last_info[policy_id].update(to_log)
+
+ @staticmethod
+ def _unroll_into_logs(
+ dict_: Dict, to_log_list_list: List[List[Dict]]
+ ) -> Dict:
for worker_idx, to_log_list in enumerate(to_log_list_list):
for to_log in to_log_list:
for k, v in to_log.items():
@@ -252,12 +186,13 @@ def exec_in_each_policy(worker):
else:
key = k
- if key not in result.keys():
- result[key] = v
+ if key not in dict_.keys():
+ dict_[key] = v
else:
- raise ValueError(
- f"key:{key} already exists in result.keys(): {result.keys()}"
+ logger.warning(
+ f"key:{key} already exists in result.keys()"
)
+ return dict_
return LoggingCallbacks
@@ -275,6 +210,33 @@ def get_log_from_policy(policy: Policy, policy_id: PolicyID) -> dict:
return to_log
+def get_explore_temperature_from_policy(
+ policy: Policy, policy_id: PolicyID
+) -> dict:
+ """
+    If it exists, get the temperature from the exploration policy of a
+    Policy.
+ """
+ to_log = {}
+ if hasattr(policy, "exploration"):
+ exploration_obj = policy.exploration
+ if hasattr(exploration_obj, "temperature"):
+ to_log[f"{policy_id}/temperature"] = exploration_obj.temperature
+
+ return to_log
+
+
+def get_weights_from_policy(policy: Policy, policy_id: PolicyID) -> dict:
+ """Gets the to_log var from a policy and rename its keys, adding the policy_id as a prefix."""
+ to_log = {}
+ weights = policy.get_weights()
+
+ for k, v in weights.items():
+ if isinstance(v, Iterable):
+ to_log[f"{policy_id}/{k}"] = v
+
+ return to_log
+
+
def augment_stats_fn_wt_additionnal_logs(
stats_function: Callable[[Policy, SampleBatch], Dict[str, TensorType]]
):
@@ -296,7 +258,9 @@ def wt_additional_info(
if policy.config["framework"] == "torch":
stats_to_log.update(_log_action_prob_pytorch(policy, train_batch))
else:
- logger.warning("wt_additional_info workin only for PyTorch")
+ logger.warning(
+ "wt_additional_info (stats_fn) working only for PyTorch"
+ )
return stats_to_log
@@ -316,15 +280,20 @@ def _log_action_prob_pytorch(
# TODO add entropy
to_log = {}
if isinstance(policy.action_space, gym.spaces.Discrete):
-
- assert (
- train_batch["action_dist_inputs"].dim() == 2
- ), "Do not support nested discrete spaces"
-
- to_log = _add_action_distrib_to_log(policy, train_batch, to_log)
- to_log = _add_entropy_to_log(train_batch, to_log)
- to_log = _add_proba_of_action_played(train_batch, to_log)
- to_log = _add_q_values(policy, train_batch, to_log)
+ if train_batch.ACTION_DIST_INPUTS in train_batch.keys():
+ assert (
+ train_batch[train_batch.ACTION_DIST_INPUTS].dim() == 2
+ ), "Do not support nested discrete spaces"
+
+ to_log = _add_action_distrib_to_log(policy, train_batch, to_log)
+ to_log = add_entropy_to_log(train_batch, to_log)
+ to_log = _add_proba_of_action_played(train_batch, to_log)
+ to_log = _add_q_values(policy, train_batch, to_log)
+ else:
+ logger.warning(
+ "Key ACTION_DIST_INPUTS not found in train_batch. "
+ "Can't perform _log_action_prob_pytorch."
+ )
else:
raise NotImplementedError()
return to_log
@@ -344,7 +313,7 @@ def _add_action_distrib_to_log(policy, train_batch, to_log):
return to_log
-def _add_entropy_to_log(train_batch, to_log):
+def add_entropy_to_log(train_batch, to_log):
actions_proba_batch = train_batch["action_dist_inputs"]
if _is_cuda_tensor(actions_proba_batch):
@@ -357,8 +326,7 @@ def _add_entropy_to_log(train_batch, to_log):
actions_proba_batch
)
- entropy_avg = _entropy_batch_proba_distrib(actions_proba_batch)
- entropy_single = _entropy_proba_distrib(actions_proba_batch[-1, :])
+ entropy_avg, entropy_single = _compute_entropy_pytorch(actions_proba_batch)
to_log[f"entropy_buffer_samples_avg"] = entropy_avg
to_log[f"entropy_buffer_samples_single"] = entropy_single
@@ -371,24 +339,7 @@ def _is_cuda_tensor(tensor):
def _entropy_batch_proba_distrib(proba_distrib_batch):
assert len(proba_distrib_batch) > 0
- entropy_batch = [
- _entropy_proba_distrib(proba_distrib_batch[batch_idx, ...])
- for batch_idx in range(len(proba_distrib_batch))
- ]
- mean_entropy = sum(entropy_batch) / len(entropy_batch)
- return mean_entropy
-
-
-def _entropy_proba_distrib(proba_distrib):
- return sum([_entropy_proba(proba) for proba in proba_distrib])
-
-
-def _entropy_proba(proba):
- assert proba >= 0.0, f"proba currently is {proba}"
- if proba == 0.0:
- return 0.0
- else:
- return -proba * math.log(proba)
+ return Categorical(probs=proba_distrib_batch).entropy()
def _add_proba_of_action_played(train_batch, to_log):
@@ -399,7 +350,7 @@ def _add_proba_of_action_played(train_batch, to_log):
def _convert_q_values_batch_to_proba_batch(q_values_batch):
- return softmax(q_values_batch, axis=1)
+ return torch.nn.functional.softmax(q_values_batch, dim=1)
def _add_q_values(policy, train_batch, to_log):
@@ -416,15 +367,21 @@ def _add_q_values(policy, train_batch, to_log):
return to_log
-def _compute_entropy_from_raw_q_values(policy, q_values):
+def compute_entropy_from_raw_q_values(policy, q_values):
actions_proba_batch = _apply_exploration(policy, dist_inputs=q_values)
- if _is_cuda_tensor(actions_proba_batch):
- actions_proba_batch = actions_proba_batch.cpu()
actions_proba_batch = _convert_q_values_batch_to_proba_batch(
actions_proba_batch
)
- entropy_avg = _entropy_batch_proba_distrib(actions_proba_batch)
- entropy_single = _entropy_proba_distrib(actions_proba_batch[-1, :])
+ return _compute_entropy_pytorch(actions_proba_batch)
+
+
+def _compute_entropy_pytorch(actions_proba_batch):
+ entropies = _entropy_batch_proba_distrib(actions_proba_batch)
+ entropy_avg = entropies.mean()
+ entropy_single = entropies[-1]
+ if _is_cuda_tensor(actions_proba_batch):
+ entropy_avg = entropy_avg.cpu()
+ entropy_single = entropy_single.cpu()
return entropy_avg, entropy_single
@@ -500,13 +457,15 @@ def pprint_saved_metrics(file_path, keywords_to_print=None):
pp.pprint(metrics)
-def _log_learning_rate(policy):
+def log_learning_rate(policy):
to_log = {}
if hasattr(policy, "cur_lr"):
to_log["cur_lr"] = policy.cur_lr
for j, opt in enumerate(policy._optimizers):
if hasattr(opt, "param_groups"):
to_log[f"opt{j}_lr"] = [p["lr"] for p in opt.param_groups][0]
+ else:
+ print("opt doesn't have attr param_groups")
return to_log
diff --git a/marltoolbox/utils/log/model_summarizer.py b/marltoolbox/utils/log/model_summarizer.py
new file mode 100644
index 0000000..5aa1380
--- /dev/null
+++ b/marltoolbox/utils/log/model_summarizer.py
@@ -0,0 +1,100 @@
+import logging
+
+import torch
+from ray.rllib.policy.torch_policy import TorchPolicy
+from torch.nn import Module
+from ray.rllib.evaluation import RolloutWorker
+
+logger = logging.getLogger(__name__)
+
+
+class ModelSummarizer:
+ """
+    Helper to log, for every torch.nn.Module in every policy, the
+    architecture and some parameter statistics.
+ """
+
+ @staticmethod
+ def for_every_policy_print_model_stats(worker: RolloutWorker):
+ """
+        For every policy in the worker, log the architecture of all torch
+        modules and some statistics about their parameters.
+
+        :param worker: RolloutWorker whose policies will be summarized
+ """
+ for policy_id, policy in worker.policy_map.items():
+ msg = f"===== Models summaries policy_id {policy_id} ====="
+ print(msg)
+ logger.info(msg)
+ ModelSummarizer._print_model_summary(policy)
+ ModelSummarizer._count_parameters_in_every_modules(policy)
+
+ @staticmethod
+ def _print_model_summary(policy: TorchPolicy):
+ if isinstance(policy, TorchPolicy):
+ for k, v in policy.__dict__.items():
+ if isinstance(v, Module):
+ msg = f"{k}, {v}"
+ print(msg)
+ logger.info(msg)
+
+ @staticmethod
+ def _count_parameters_in_every_modules(policy: TorchPolicy):
+ if isinstance(policy, TorchPolicy):
+ for k, v in policy.__dict__.items():
+ if isinstance(v, Module):
+ ModelSummarizer._count_and_log_for_one_module(policy, k, v)
+
+ @staticmethod
+ def _count_and_log_for_one_module(
+ policy: TorchPolicy, module_name: str, module: torch.nn.Module
+ ):
+ n_param = ModelSummarizer._count_parameters(module, module_name)
+ n_param_shared_counted_once = ModelSummarizer._count_parameters(
+ module, module_name, count_shared_once=True
+ )
+ n_param_trainable = ModelSummarizer._count_parameters(
+ module, module_name, only_trainable=True
+ )
+ ModelSummarizer._log_values_in_to_log(
+ policy,
+ {
+ f"module_{module_name}_n_param": n_param,
+ f"module_{module_name}_n_param_shared_counted_once": n_param_shared_counted_once,
+ f"module_{module_name}_n_param_trainable": n_param_trainable,
+ },
+ )
+
+ @staticmethod
+ def _log_values_in_to_log(policy, dictionary):
+ if hasattr(policy, "to_log"):
+ policy.to_log.update(dictionary)
+
+ @staticmethod
+ def _count_parameters(
+ m: torch.nn.Module,
+ module_name: str,
+ count_shared_once: bool = False,
+ only_trainable: bool = False,
+ ):
+ """
+        Returns the total number of parameters used by `m`. If
+        `count_shared_once` is True, shared parameters are counted only
+        once; if `only_trainable` is True, only parameters with
+        `requires_grad = True` are included.
+ """
+ parameters = m.parameters()
+ if only_trainable:
+ parameters = list(p for p in parameters if p.requires_grad)
+ if count_shared_once:
+ parameters = dict((p.data_ptr(), p) for p in parameters).values()
+ number_of_parameters = sum(p.numel() for p in parameters)
+
+ msg = (
+ f"{module_name}: "
+ f"number_of_parameters: {number_of_parameters} "
+ f"(only_trainable: {only_trainable}, "
+ f"count_shared_once: {count_shared_once})"
+ )
+ print(msg)
+ logger.info(msg)
+ return number_of_parameters
diff --git a/marltoolbox/utils/miscellaneous.py b/marltoolbox/utils/miscellaneous.py
index 290bf52..5890de9 100644
--- a/marltoolbox/utils/miscellaneous.py
+++ b/marltoolbox/utils/miscellaneous.py
@@ -71,7 +71,7 @@ def move_to_key(dict_: dict, key: str):
:param dict_: dict or nesyed dict
:param key: key or serie of key joined by a '.'
- :return: (the lower level dict, lower level key, the final value,
+ :return: Tuple(the lower level dict, lower level key, the final value,
boolean for final value found)
"""
assert isinstance(dict_, dict)
@@ -79,7 +79,10 @@ def move_to_key(dict_: dict, key: str):
found = True
for k in key.split("."):
if not found:
- print(f"Intermediary key: {k} not found in full key: {key}")
+ print(
+ f"Intermediary key: {k} not found with full key: {key} "
+ f"and dict: {dict_}"
+ )
return
dict_ = current_value
if k in current_value.keys():
@@ -89,42 +92,6 @@ def move_to_key(dict_: dict, key: str):
return dict_, k, current_value, found
-def extract_checkpoints(tune_experiment_analysis):
- logger.info("start extract_checkpoints")
-
- for trial in tune_experiment_analysis.trials:
- checkpoints = tune_experiment_analysis.get_trial_checkpoints_paths(
- trial, tune_experiment_analysis.default_metric
- )
- assert len(checkpoints) > 0
-
- all_best_checkpoints_per_trial = [
- tune_experiment_analysis.get_best_checkpoint(
- trial,
- metric=tune_experiment_analysis.default_metric,
- mode=tune_experiment_analysis.default_mode,
- )
- for trial in tune_experiment_analysis.trials
- ]
-
- for checkpoint in all_best_checkpoints_per_trial:
- assert checkpoint is not None
-
- logger.info("end extract_checkpoints")
- return all_best_checkpoints_per_trial
-
-
-def extract_config_values_from_tune_analysis(tune_experiment_analysis, key):
- values = []
- for trial in tune_experiment_analysis.trials:
- dict_, k, current_value, found = move_to_key(trial.config, key)
- if found:
- values.append(current_value)
- else:
- values.append(None)
- return values
-
-
def merge_policy_postprocessing_fn(*postprocessing_fn_list):
"""
Merge several callback class together.
@@ -192,74 +159,12 @@ def set_config_for_evaluation(
return config_copy
-def filter_tune_results(
- tune_analysis,
- metric,
- metric_threshold: float,
- metric_mode="last-5-avg",
- threshold_mode="above",
-):
- assert threshold_mode in ("above", "equal", "below")
- assert metric_mode in (
- "avg",
- "min",
- "max",
- "last",
- "last-5-avg",
- "last-10-avg",
- )
- print("Before trial filtering:", len(tune_analysis.trials), "trials")
- trials_filtered = []
- print(
- "metric_threshold", metric_threshold, "threshold_mode", threshold_mode
- )
- for trial_idx, trial in enumerate(tune_analysis.trials):
- available_metrics = trial.metric_analysis
- print(
- f"trial_idx {trial_idx} "
- f"available_metrics[{metric}][{metric_mode}] "
- f"{available_metrics[metric][metric_mode]}"
- )
- if (
- threshold_mode == "above"
- and available_metrics[metric][metric_mode] > metric_threshold
- ):
- trials_filtered.append(trial)
- elif (
- threshold_mode == "equal"
- and available_metrics[metric][metric_mode] == metric_threshold
- ):
- trials_filtered.append(trial)
- elif (
- threshold_mode == "below"
- and available_metrics[metric][metric_mode] < metric_threshold
- ):
- trials_filtered.append(trial)
- else:
- print(f"filter trial {trial_idx}")
- tune_analysis.trials = trials_filtered
- print("After trial filtering:", len(tune_analysis.trials), "trials")
- return tune_analysis
-
-
def get_random_seeds(n_seeds):
timestamp = int(time.time())
seeds = [seed + timestamp for seed in list(range(n_seeds))]
return seeds
-def list_all_files_in_one_dir_tree(path):
- if not os.path.exists(path):
- raise FileExistsError(f"path doesn't exist: {path}")
- file_list = []
- for root, dirs, files in os.walk(path):
- for file in files:
- # append the file name to the list
- file_list.append(os.path.join(root, file))
- print(len(file_list), "files found")
- return file_list
-
-
def ignore_str_containing_keys(str_list, ignore_keys):
str_list_filtered = [
file_path
@@ -368,41 +273,6 @@ def _get_experiment_state_file_path(one_checkpoint_path, split_path_n_times=1):
return json_file_path
-def check_learning_achieved(
- tune_results,
- metric="episode_reward_mean",
- trial_idx=0,
- max_: float = None,
- min_: float = None,
- equal_: float = None,
-):
- assert max_ is not None or min_ is not None or equal_ is not None
-
- last_results = tune_results.trials[trial_idx].last_result
- _, _, value, found = move_to_key(last_results, key=metric)
- assert (
- found
- ), f"metric {metric} not found inside last_results {last_results}"
-
- msg = (
- f"Trial {trial_idx} achieved "
- f"{value}"
- f" on metric {metric}. This is a success if the value is below"
- f" {max_} or above {min_} or equal to {equal_}."
- )
-
- logger.info(msg)
- print(msg)
- if min_ is not None:
- assert value >= min_, f"value {value} must be above min_ {min_}"
- if max_ is not None:
- assert value <= max_, f"value {value} must be below max_ {max_}"
- if equal_ is not None:
- assert value == equal_, (
- f"value {value} must be equal to equal_ " f"{equal_}"
- )
-
-
def assert_if_key_in_dict_then_args_are_none(dict_, key, *args):
if key in dict_.keys():
for arg in args:
@@ -422,28 +292,12 @@ def read_from_dict_default_to_args(dict_, key, *args):
def filter_sample_batch(
samples: SampleBatch, filter_key, remove=True, copy_data=False
) -> SampleBatch:
- filter = samples.data[filter_key]
+ filter = samples.columns([filter_key])[0]
if remove:
- # torch logical not
+ assert isinstance(
+ filter, np.ndarray
+ ), f"type {type(filter)} for filter_key {filter_key}"
filter = ~filter
return SampleBatch(
- {
- k: np.array(v, copy=copy_data)[filter]
- for (k, v) in samples.data.items()
- }
+ {k: np.array(v, copy=copy_data)[filter] for (k, v) in samples.items()}
)
-
-
-def extract_metric_values_per_trials(
- tune_analysis,
- metric="episode_reward_mean",
-):
- metric_values = []
- for trial in tune_analysis.trials:
- last_results = trial.last_result
- _, _, value, found = move_to_key(last_results, key=metric)
- assert (
- found
- ), f"metric: {metric} not found in last_results: {last_results}"
- metric_values.append(value)
- return metric_values
diff --git a/marltoolbox/utils/path.py b/marltoolbox/utils/path.py
new file mode 100644
index 0000000..106ae32
--- /dev/null
+++ b/marltoolbox/utils/path.py
@@ -0,0 +1,197 @@
+import json
+import os
+from typing import List
+
+from marltoolbox.utils import miscellaneous
+from marltoolbox.utils.tune_analysis import ABOVE, BELOW, EQUAL
+
+
+def get_unique_child_dir(_dir: str):
+ """
+ Return the path to the unique dir inside the given dir.
+
+ :param _dir: path to given dir
+ :return: path to the unique dir inside the given dir
+ """
+
+ list_child_dir = os.listdir(_dir)
+ list_child_dir = [
+ os.path.join(_dir, child_dir) for child_dir in list_child_dir
+ ]
+ list_child_dir = keep_dirs_only(list_child_dir)
+ assert len(list_child_dir) == 1, f"{list_child_dir}"
+ unique_child_dir = list_child_dir[0]
+ return unique_child_dir
+
+
+def try_get_unique_child_dir(_dir: str):
+ """
+ If it exists, returns the path to the unique dir inside the given dir.
+ Otherwise returns None.
+
+ :param _dir: path to given dir
+ :return: path to the unique dir inside the given dir or if it doesn't
+ exist None
+ """
+
+ try:
+ unique_child_dir = get_unique_child_dir(_dir)
+ return unique_child_dir
+ except AssertionError:
+ return None
+
+
+def list_all_files_in_one_dir_tree(path: str) -> List[str]:
+ """
+ List all the files in the tree starting at the given path.
+
+ :param path:
+ :return: list of all the files
+ """
+ if not os.path.exists(path):
+ raise FileExistsError(f"path doesn't exist: {path}")
+ file_list = []
+ for root, dirs, files in os.walk(path):
+ for file in files:
+ # append the file name to the list
+ file_list.append(os.path.join(root, file))
+ print(len(file_list), "files found")
+ return file_list
+
+
+def get_children_paths_wt_selecting_filter(
+ parent_dir_path: str, _filter: str
+) -> List[str]:
+ """
+ Return all children dir paths after selecting those containing the
+ _filter.
+
+ :param parent_dir_path:
+ :param _filter: to select the paths to keep
+ :return: list of paths which contain the given filter.
+ """
+ return _get_children_paths_filters(
+ parent_dir_path, selecting_filter=_filter
+ )
+
+
+def get_children_paths_wt_discarding_filter(
+ parent_dir_path: str, _filter: str
+) -> List[str]:
+ """
+ Return all children dir paths after selecting those NOT containing the
+ _filter.
+
+ :param parent_dir_path:
+ :param _filter: to select the paths to remove
+ :return: list of paths which don't contain the given filter.
+ """
+
+ return _get_children_paths_filters(
+ parent_dir_path, discarding_filter=_filter
+ )
+
+
+def _get_children_paths_filters(
+ parent_dir_path: str,
+ selecting_filter: str = None,
+ discarding_filter: str = None,
+):
+ filtered_children = os.listdir(parent_dir_path)
+ if selecting_filter is not None:
+ filtered_children = [
+ filename
+ for filename in filtered_children
+ if selecting_filter in filename
+ ]
+ if discarding_filter is not None:
+ filtered_children = [
+ filename
+ for filename in filtered_children
+ if discarding_filter not in filename
+ ]
+ filtered_children_path = [
+ os.path.join(parent_dir_path, filename)
+ for filename in filtered_children
+ ]
+ return filtered_children_path
+
+
+def get_params_for_replicate(trial_dir_path: str) -> dict:
+ """
+    Get the parameters from the json file saved in the dir of a Tune/RLLib
+    trial.
+
+    :param trial_dir_path: path to a single tune.Trial (inside an experiment)
+ :return: dict of parameters used for the trial
+ """
+ parameter_json_path = os.path.join(trial_dir_path, "params.json")
+ params = _read_json_file(parameter_json_path)
+ return params
+
+
+def get_results_for_replicate(trial_dir_path: str) -> list:
+ """
+    Get the results for all episodes from the file saved in the
+    dir of a Tune/RLLib trial.
+
+    :param trial_dir_path: path to a single tune.Trial (inside an experiment)
+ :return: list of lines of results (one line per episode)
+ """
+ results_file_path = os.path.join(trial_dir_path, "result.json")
+ results = _read_all_lines_of_file(results_file_path)
+ # Remove empty last line
+ if len(results[-1]) == 0:
+ results = results[:-1]
+ results = [json.loads(line) for line in results]
+ return results
+
+
+def _read_json_file(json_file_path: str):
+ with open(json_file_path) as json_file:
+ json_object = json.load(json_file)
+ return json_object
+
+
+def _read_all_lines_of_file(file_path: str) -> list:
+ with open(file_path) as file:
+ lines = list(file)
+ return lines
+
+
+def keep_dirs_only(paths: list) -> list:
+ """Keep only the directories"""
+ return [path for path in paths if os.path.isdir(path)]
+
+
+def filter_list_of_replicates_by_results(
+ replicate_paths: list,
+ filter_key: str,
+ filter_threshold: float,
+ filter_mode: str = ABOVE,
+) -> list:
+ print("Going to start filtering replicate_paths")
+ print("len(replicate_paths)", len(replicate_paths))
+ filtered_replicate_paths = []
+ for replica_path in replicate_paths:
+ replica_results = get_results_for_replicate(replica_path)
+ last_result = replica_results[-1]
+ assert isinstance(last_result, dict)
+ _, _, current_value, found = miscellaneous.move_to_key(
+ last_result, filter_key
+ )
+ assert found, (
+ f"filter_key {filter_key} not found in last_result "
+ f"{last_result}"
+ )
+ if filter_mode == ABOVE and current_value > filter_threshold:
+ filtered_replicate_paths.append(replica_path)
+ elif filter_mode == EQUAL and current_value == filter_threshold:
+ filtered_replicate_paths.append(replica_path)
+ elif filter_mode == BELOW and current_value < filter_threshold:
+ filtered_replicate_paths.append(replica_path)
+ else:
+ print(f"filtering out replica_path {replica_path}")
+ print("After filtering:")
+ print("len(filtered_replicate_paths)", len(filtered_replicate_paths))
+ return filtered_replicate_paths
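+
+
+# Usage sketch (the experiment path, filter and metric below are
+# illustrative only):
+#   replicate_paths = get_children_paths_wt_selecting_filter(
+#       "~/ray_results/my_experiment", _filter="DQN_"
+#   )
+#   kept_paths = filter_list_of_replicates_by_results(
+#       replicate_paths,
+#       filter_key="episode_reward_mean",
+#       filter_threshold=0.0,
+#       filter_mode=ABOVE,
+#   )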
diff --git a/marltoolbox/utils/plot.py b/marltoolbox/utils/plot.py
index f478894..0b09b15 100644
--- a/marltoolbox/utils/plot.py
+++ b/marltoolbox/utils/plot.py
@@ -5,6 +5,10 @@
import matplotlib.pyplot as plt
import numpy as np
+plt.switch_backend("agg")
+plt.style.use("seaborn-whitegrid")
+plt.rcParams.update({"font.size": 12})
+
COLORS = list(mcolors.TABLEAU_COLORS) + list(mcolors.XKCD_COLORS)
RANDOM_MARKERS = ["1", "2", "3", "4", "8", "s", "p", "P", "*", "h", "H", "+"]
MARKERS = ["o", "s", "v", "^", "<", ">", "P", "X", "D", "*"] + RANDOM_MARKERS
@@ -20,7 +24,7 @@ def __init__(
xlabel: str = None,
ylabel: str = None,
display_legend: bool = True,
- legend_fontsize: str = "small",
+ legend_fontsize: str = "medium",
save_dir_path: str = None,
title: str = None,
xlim: str = None,
@@ -74,6 +78,7 @@ def __init__(
class PlotHelper:
def __init__(self, plot_config: PlotConfig):
self.plot_cfg = plot_config
+ self.additional_filename_suffix = ""
def plot_lines(self, data_groups: dict):
"""
@@ -157,13 +162,32 @@ def _get_label(self, group_id, col):
label = group_id
return label
- def _finalize_plot(self, fig):
+ def _finalize_plot(self, fig, remove_err_bars_from_legend=False):
if self.plot_cfg.display_legend:
- plt.legend(
- numpoints=1,
- frameon=True,
- fontsize=self.plot_cfg.legend_fontsize,
- )
+ if remove_err_bars_from_legend:
+ ax = plt.gca()
+ # get handles
+ handles, labels = ax.get_legend_handles_labels()
+ # remove the errorbars
+ handles = [h[0] for h in handles]
+ # use them in the legend
+ ax.legend(
+ handles,
+ labels,
+ numpoints=1,
+ frameon=True,
+ fontsize=self.plot_cfg.legend_fontsize,
+ )
+ else:
+ plt.legend(
+ numpoints=1,
+ frameon=True,
+ fontsize=self.plot_cfg.legend_fontsize,
+ )
+
+ # bbox_to_anchor = (0.66, -0.20),
+ # # loc="upper left",
+ # )
if self.plot_cfg.xlabel is not None:
plt.xlabel(self.plot_cfg.xlabel)
if self.plot_cfg.ylabel is not None:
@@ -176,10 +200,11 @@ def _finalize_plot(self, fig):
plt.ylim(self.plot_cfg.ylim)
if self.plot_cfg.background_area_coord is not None:
self._add_background_area()
+ # plt.tight_layout(rect=[0, -0.05, 1.0, 1.0])
if self.plot_cfg.save_dir_path is not None:
file_name = (
f"{self.plot_cfg.filename_prefix}_{self.plot_cfg.ylabel}_vs"
- f"_{self.plot_cfg.xlabel}.png"
+ f"_{self.plot_cfg.xlabel}{self.additional_filename_suffix}.png"
)
file_name = file_name.replace("/", "_")
file_path = os.path.join(self.plot_cfg.save_dir_path, file_name)
@@ -195,6 +220,19 @@ def plot_dots(self, data_groups: dict):
:param data_groups: dict of groups (same color and label prefix) containing a DataFrame containing (x,
y) tuples. Each column in a group DataFrame has a different marker.
"""
+ self._plot_dots_multiple_points(data_groups)
+ self._plot_dots_one_point_wt_std_dev_bars(data_groups)
+
+ def _plot_dots_multiple_points(self, data_groups: dict):
+ return self._plot_dots_customizable(data_groups)
+
+ def _plot_dots_one_point_wt_std_dev_bars(self, data_groups: dict):
+ self.additional_filename_suffix = "_wt_err_bars"
+        plot_filename = self._plot_dots_customizable(data_groups, err_bars=True)
+        self.additional_filename_suffix = ""
+        return plot_filename
+
+ def _plot_dots_customizable(self, data_groups: dict, err_bars=False):
fig = self._init_plot()
self.counter_labels = 0
@@ -203,47 +241,96 @@ def plot_dots(self, data_groups: dict):
data_groups.items()
):
new_labels_plotted = self._plot_dotes_for_one_group(
- self.plot_cfg.colors[group_index], group_id, group_df
+ self.plot_cfg.colors[group_index], group_id, group_df, err_bars
)
all_label_plotted.extend(new_labels_plotted)
print("all_label_plotted", all_label_plotted)
- return self._finalize_plot(fig)
+ return self._finalize_plot(fig, remove_err_bars_from_legend=err_bars)
- def _plot_dotes_for_one_group(self, group_color, group_id, group_df):
+ def _plot_dotes_for_one_group(
+ self, group_color, group_id, group_df, err_bars=False
+ ):
label_plotted = []
for col in group_df.columns:
- x, y = self._select_n_points_to_plot(group_df, col)
- x, y = self._add_jitter_to_points(x, y)
- x, y = self._apply_scale_multiplier(x, y)
- label = self._get_label(group_id, col)
-
- plt.plot(
- x,
- y,
- markerfacecolor="none"
- if self.plot_cfg.empty_markers
- else group_color,
- markeredgecolor=group_color,
- linestyle="None",
- marker=self.plot_cfg.markers[self.counter_labels],
- color=group_color,
- label=label,
- alpha=self.plot_cfg.alpha,
- markersize=self.plot_cfg.markersize,
- )
- self.counter_labels += 1
+ if not err_bars:
+ label = self._plot_wtout_err_bars(
+ group_df, col, group_id, group_color
+ )
+ else:
+ label = self._plot_wt_err_bars(
+ group_df, col, group_id, group_color
+ )
label_plotted.append(label)
return label_plotted
+ def _plot_wtout_err_bars(self, group_df, col, group_id, group_color):
+ x, y = self._select_n_points_to_plot(group_df, col)
+ x, y = self._add_jitter_to_points(x, y)
+ x, y = self._apply_scale_multiplier(x, y)
+ label = self._get_label(group_id, col)
+
+ plt.plot(
+ x,
+ y,
+ markerfacecolor="none"
+ if self.plot_cfg.empty_markers
+ else group_color,
+ markeredgecolor=group_color,
+ linestyle="None",
+ marker=self.plot_cfg.markers[self.counter_labels],
+ color=group_color,
+ label=label,
+ alpha=self.plot_cfg.alpha,
+ markersize=self.plot_cfg.markersize,
+ )
+ self.counter_labels += 1
+ return label
+
+ def _plot_wt_err_bars(self, group_df, col, group_id, group_color):
+ x, y = self._select_all_points_to_plot(group_df, col)
+ x, y = self._apply_scale_multiplier(x, y)
+ label = self._get_label(group_id, col)
+ x_mean = np.array(x).mean()
+ y_mean = np.array(y).mean()
+ x_std_err = np.array(x).std() / np.sqrt(len(x))
+ y_std_err = np.array(y).std() / np.sqrt(len(y))
+
+ plt.errorbar(
+ x_mean,
+ y_mean,
+ xerr=x_std_err,
+ yerr=y_std_err,
+ markerfacecolor="none"
+ if self.plot_cfg.empty_markers
+ else group_color,
+ markeredgecolor=group_color,
+ linestyle="None",
+ marker=self.plot_cfg.markers[self.counter_labels],
+ color=group_color,
+ label=label,
+ alpha=self.plot_cfg.alpha,
+ markersize=36 * 2.0
+ if self.plot_cfg.markersize is None
+ else self.plot_cfg.markersize * 2.0,
+ )
+ self.counter_labels += 1
+ return label
+
def _select_n_points_to_plot(self, group_df, col):
if self.plot_cfg.plot_max_n_points is not None:
n_points_to_plot = min(
self.plot_cfg.plot_max_n_points, len(group_df)
)
print(f"Selected {n_points_to_plot} n_points_to_plot")
+ return self._get_points_to_plot(group_df, n_points_to_plot, col)
else:
- n_points_to_plot = len(group_df)
+ return self._select_all_points_to_plot(group_df, col)
+
+ def _select_all_points_to_plot(self, group_df, col):
+ return self._get_points_to_plot(group_df, len(group_df), col)
+
+ def _get_points_to_plot(self, group_df, n_points_to_plot, col):
group_df_sample = group_df.sample(n=int(n_points_to_plot))
points = group_df_sample[col].tolist()
x, y = [p[0] for p in points], [p[1] for p in points]
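
The new _plot_wt_err_bars path collapses each group to its mean point with standard-error-of-the-mean bars. A standalone sketch of that computation, outside the PlotHelper machinery, with illustrative values:

import matplotlib.pyplot as plt
import numpy as np

# Hypothetical (x, y) points for one group, e.g. one point per replicate.
x = np.array([0.8, 1.0, 1.1, 0.9])
y = np.array([1.9, 2.1, 2.0, 2.2])

# Collapse the group to its mean point; the error bars are the standard error
# of the mean (std / sqrt(n)), as in _plot_wt_err_bars above.
plt.errorbar(
    x.mean(),
    y.mean(),
    xerr=x.std() / np.sqrt(len(x)),
    yerr=y.std() / np.sqrt(len(y)),
    linestyle="None",
    marker="o",
)
plt.savefig("group_mean_wt_err_bars.png")
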
diff --git a/marltoolbox/utils/policy.py b/marltoolbox/utils/policy.py
index c7a59a2..40ac964 100644
--- a/marltoolbox/utils/policy.py
+++ b/marltoolbox/utils/policy.py
@@ -5,7 +5,6 @@
from ray.rllib.policy.torch_policy import LearningRateSchedule
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.typing import TrainerConfigDict
-
from marltoolbox.utils.restore import LOAD_FROM_CONFIG_KEY
@@ -21,39 +20,53 @@ def get_tune_policy_class(PolicyClass):
"""
class FrozenPolicyFromTuneTrainer(PolicyClass):
-
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- config: TrainerConfigDict):
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ config: TrainerConfigDict,
+ ):
print("__init__ FrozenPolicyFromTuneTrainer")
self.tune_config = config["tune_config"]
TuneTrainerClass = self.tune_config["TuneTrainerClass"]
self.tune_trainer = TuneTrainerClass(config=self.tune_config)
self.load_checkpoint(
- config.pop(LOAD_FROM_CONFIG_KEY, (None, None)))
+ config.pop(LOAD_FROM_CONFIG_KEY, (None, None))
+ )
+ self._to_log = {}
super().__init__(observation_space, action_space, config)
- def compute_actions(self,
- obs_batch,
- state_batches=None,
- prev_action_batch=None,
- prev_reward_batch=None,
- info_batch=None,
- episodes=None,
- **kwargs):
- actions, state_out, extra_fetches = \
- self.tune_trainer.compute_actions(self.policy_id, obs_batch)
+ def _compute_action_helper(
+ self,
+ input_dict,
+ *args,
+ **kwargs,
+ ):
+ # print('input_dict["obs"]', input_dict["obs"])
+ (
+ actions,
+ state_out,
+ extra_fetches,
+ ) = self.tune_trainer.compute_actions(
+ self.policy_id, input_dict["obs"]
+ )
return actions, state_out, extra_fetches
+ def _initialize_loss_from_dummy_batch(self, *args, **kwargs):
+ pass
+
def learn_on_batch(self, samples):
raise NotImplementedError(
- "FrozenPolicyFromTuneTrainer policy can't be trained")
+ "FrozenPolicyFromTuneTrainer policy can't be trained"
+ )
def get_weights(self):
- return {"checkpoint_path": self.checkpoint_path,
- "policy_id": self.policy_id}
+ return {
+ "checkpoint_path": self.checkpoint_path,
+ "policy_id": self.policy_id,
+ }
def set_weights(self, weights):
checkpoint_path = weights["checkpoint_path"]
@@ -65,6 +78,27 @@ def load_checkpoint(self, checkpoint_tuple):
if self.checkpoint_path is not None:
self.tune_trainer.load_checkpoint(self.checkpoint_path)
+ @property
+ def to_log(self):
+ to_log = {
+ "frozen_policy": self._to_log,
+ "nested_tune_policy": {
+ f"policy_{algo_idx}": algo.to_log
+ for algo_idx, algo in enumerate(self.algorithms)
+ if hasattr(algo, "to_log")
+ },
+ }
+ return to_log
+
+ @to_log.setter
+ def to_log(self, value):
+ if value == {}:
+ for algo in self.algorithms:
+ if hasattr(algo, "to_log"):
+ algo.to_log = {}
+
+ self._to_log = value
+
return FrozenPolicyFromTuneTrainer
@@ -81,14 +115,17 @@ def __init__(self, lr, lr_schedule):
else:
if isinstance(lr_schedule, Iterable):
self.lr_schedule = PiecewiseSchedule(
- lr_schedule, outside_value=lr_schedule[-1][-1],
- framework=None)
+ lr_schedule,
+ outside_value=lr_schedule[-1][-1],
+ framework=None,
+ )
else:
self.lr_schedule = lr_schedule
-def my_setup_early_mixins(policy: Policy, obs_space, action_space,
- config: TrainerConfigDict) -> None:
- MyLearningRateSchedule.__init__(policy,
- config["lr"],
- config["lr_schedule"])
+def my_setup_early_mixins(
+ policy: Policy, obs_space, action_space, config: TrainerConfigDict
+) -> None:
+ MyLearningRateSchedule.__init__(
+ policy, config["lr"], config["lr_schedule"]
+ )
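
For reference, a hedged sketch of how a config for the FrozenPolicyFromTuneTrainer wrapper above is expected to be wired; MyTuneTrainer is a hypothetical Tune-style trainer implementing the small interface the wrapper calls into:

from marltoolbox.utils.restore import LOAD_FROM_CONFIG_KEY


class MyTuneTrainer:
    """Hypothetical Tune-style trainer with the interface the wrapper uses."""

    def __init__(self, config):
        self.config = config

    def compute_actions(self, policy_id, obs_batch):
        # Must return (actions, state_out, extra_fetches).
        return [0 for _ in obs_batch], [], {}

    def load_checkpoint(self, checkpoint_path):
        pass


# The frozen policy reads the trainer class from "tune_config" and an optional
# (checkpoint_path, policy_id) tuple from LOAD_FROM_CONFIG_KEY.
frozen_policy_config = {
    "tune_config": {"TuneTrainerClass": MyTuneTrainer},
    LOAD_FROM_CONFIG_KEY: (None, "player_row"),
}
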
diff --git a/marltoolbox/utils/postprocessing.py b/marltoolbox/utils/postprocessing.py
index 894e3e1..76c0069 100644
--- a/marltoolbox/utils/postprocessing.py
+++ b/marltoolbox/utils/postprocessing.py
@@ -4,14 +4,17 @@
import numpy as np
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.evaluation import MultiAgentEpisode
-from ray.rllib.evaluation.postprocessing import discount
+from ray.rllib.evaluation.postprocessing import discount_cumsum
from ray.rllib.evaluation.sampler import _get_or_raise
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentID, PolicyID
+from ray.rllib.utils.schedules import Schedule
-from marltoolbox.utils.miscellaneous import \
- assert_if_key_in_dict_then_args_are_none, read_from_dict_default_to_args
+from marltoolbox.utils.miscellaneous import (
+ assert_if_key_in_dict_then_args_are_none,
+ read_from_dict_default_to_args,
+)
if TYPE_CHECKING:
from ray.rllib.evaluation import RolloutWorker
@@ -37,22 +40,25 @@
ADD_OPPONENT_ACTION = "add_opponent_action"
ADD_OPPONENT_NEG_REWARD = "add_opponent_neg_reward"
-ADD_WELFARE_CONFIG_KEYS = (ADD_UTILITARIAN_WELFARE,
- ADD_INEQUITY_AVERSION_WELFARE)
+ADD_WELFARE_CONFIG_KEYS = (
+ ADD_UTILITARIAN_WELFARE,
+ ADD_INEQUITY_AVERSION_WELFARE,
+)
def welfares_postprocessing_fn(
- add_utilitarian_welfare: bool = None,
- add_egalitarian_welfare: bool = None,
- add_nash_welfare: bool = None,
- add_opponent_action: bool = None,
- add_opponent_neg_reward: bool = None,
- add_inequity_aversion_welfare: bool = None,
- inequity_aversion_alpha: float = None,
- inequity_aversion_beta: float = None,
- inequity_aversion_gamma: float = None,
- inequity_aversion_lambda: float = None,
- additional_fn: list = []):
+ add_utilitarian_welfare: bool = None,
+ add_egalitarian_welfare: bool = None,
+ add_nash_welfare: bool = None,
+ add_opponent_action: bool = None,
+ add_opponent_neg_reward: bool = None,
+ add_inequity_aversion_welfare: bool = None,
+ inequity_aversion_alpha: float = None,
+ inequity_aversion_beta: float = None,
+ inequity_aversion_gamma: float = None,
+ inequity_aversion_lambda: float = None,
+ additional_fn: list = [],
+):
"""
Generate a postprocess_fn that first add a welfare if you chose so and
then call a list of additional postprocess_fn to further modify the
@@ -60,7 +66,10 @@ def welfares_postprocessing_fn(
The parameters used to add a welfare can be given as arguments or
will be read in the policy config dict (this should be preferred since this
- allows for hyperparameter search over these parameters with Tune).
+        allows for hyperparameter search over these parameters with Tune). When
+        read from the config, they are read from the keys defined by
+        ADD_INEQUITY_AVERSION_WELFARE and similar constants; the value under
+        such a key must be a tuple when several parameters are provided.
:param add_utilitarian_welfare:
:param add_egalitarian_welfare:
@@ -80,121 +89,202 @@ def welfares_postprocessing_fn(
:return:
"""
+    # TODO refactor into instantiating an object (from a new class) and
+    #  returning one of its methods OR use a named tuple
def postprocess_fn(policy, sample_batch, other_agent_batches, episode):
+ if other_agent_batches is None:
+ logger.warning(
+ "no other_agent_batches given for welfare " "postprocessing"
+ )
+ return sample_batch
_assert_using_config_xor_args(policy)
parameters = _read_parameters_from_config_default_to_args(policy)
+ parameters = _get_values_from_any_scheduler(parameters, policy)
sample_batch = _add_welfare_to_own_batch(
- sample_batch, other_agent_batches, episode, policy, *parameters)
+ sample_batch, other_agent_batches, episode, policy, *parameters
+ )
sample_batch = _call_list_of_additional_fn(
- additional_fn, sample_batch, other_agent_batches, episode, policy)
+ additional_fn, sample_batch, other_agent_batches, episode, policy
+ )
return sample_batch
def _assert_using_config_xor_args(policy):
assert_if_key_in_dict_then_args_are_none(
- policy.config, "add_utilitarian_welfare", add_utilitarian_welfare)
+ policy.config, ADD_UTILITARIAN_WELFARE, add_utilitarian_welfare
+ )
assert_if_key_in_dict_then_args_are_none(
- policy.config, "add_inequity_aversion_welfare",
- add_inequity_aversion_welfare, inequity_aversion_alpha,
- inequity_aversion_beta, inequity_aversion_gamma,
- inequity_aversion_lambda)
+ policy.config,
+ ADD_INEQUITY_AVERSION_WELFARE,
+ add_inequity_aversion_welfare,
+ inequity_aversion_alpha,
+ inequity_aversion_beta,
+ inequity_aversion_gamma,
+ inequity_aversion_lambda,
+ )
assert_if_key_in_dict_then_args_are_none(
- policy.config, "add_nash_welfare", add_nash_welfare)
+ policy.config, ADD_NASH_WELFARE, add_nash_welfare
+ )
assert_if_key_in_dict_then_args_are_none(
- policy.config, "add_egalitarian_welfare", add_egalitarian_welfare)
+ policy.config, ADD_EGALITARIAN_WELFARE, add_egalitarian_welfare
+ )
assert_if_key_in_dict_then_args_are_none(
- policy.config, "add_opponent_action", add_opponent_action)
+ policy.config, ADD_OPPONENT_ACTION, add_opponent_action
+ )
assert_if_key_in_dict_then_args_are_none(
- policy.config, "add_opponent_neg_reward", add_opponent_neg_reward)
+ policy.config, ADD_OPPONENT_NEG_REWARD, add_opponent_neg_reward
+ )
def _read_parameters_from_config_default_to_args(policy):
add_utilitarian_w = read_from_dict_default_to_args(
- policy.config, ADD_UTILITARIAN_WELFARE, add_utilitarian_welfare)
- add_ia_w, ia_alpha, ia_beta, ia_gamma, ia_lambda = \
- read_from_dict_default_to_args(
- policy.config, ADD_INEQUITY_AVERSION_WELFARE,
- add_inequity_aversion_welfare, inequity_aversion_alpha,
- inequity_aversion_beta, inequity_aversion_gamma,
- inequity_aversion_lambda)
+ policy.config, ADD_UTILITARIAN_WELFARE, add_utilitarian_welfare
+ )
+ (
+ add_ia_w,
+ ia_alpha,
+ ia_beta,
+ ia_gamma,
+ ia_lambda,
+ ) = read_from_dict_default_to_args(
+ policy.config,
+ ADD_INEQUITY_AVERSION_WELFARE,
+ add_inequity_aversion_welfare,
+ inequity_aversion_alpha,
+ inequity_aversion_beta,
+ inequity_aversion_gamma,
+ inequity_aversion_lambda,
+ )
add_nash_w = read_from_dict_default_to_args(
- policy.config, ADD_NASH_WELFARE, add_nash_welfare)
+ policy.config, ADD_NASH_WELFARE, add_nash_welfare
+ )
add_egalitarian_w = read_from_dict_default_to_args(
- policy.config, ADD_EGALITARIAN_WELFARE, add_egalitarian_welfare)
+ policy.config, ADD_EGALITARIAN_WELFARE, add_egalitarian_welfare
+ )
add_opponent_a = read_from_dict_default_to_args(
- policy.config, ADD_OPPONENT_ACTION, add_opponent_action)
+ policy.config, ADD_OPPONENT_ACTION, add_opponent_action
+ )
add_opponent_neg_r = read_from_dict_default_to_args(
- policy.config, ADD_OPPONENT_NEG_REWARD, add_opponent_neg_reward)
-
- return add_utilitarian_w, \
- add_ia_w, ia_alpha, ia_beta, ia_gamma, ia_lambda, \
- add_nash_w, add_egalitarian_w, \
- add_opponent_a, add_opponent_neg_r
+ policy.config, ADD_OPPONENT_NEG_REWARD, add_opponent_neg_reward
+ )
+
+ return (
+ add_utilitarian_w,
+ add_ia_w,
+ ia_alpha,
+ ia_beta,
+ ia_gamma,
+ ia_lambda,
+ add_nash_w,
+ add_egalitarian_w,
+ add_opponent_a,
+ add_opponent_neg_r,
+ )
+
+ def _get_values_from_any_scheduler(parameters, policy):
+ logger.debug("_get_values_from_any_scheduler")
+ new_parameters = []
+ for param in parameters:
+ if isinstance(param, Schedule):
+ value_from_scheduler = param.value(policy.global_timestep)
+ new_parameters.append(value_from_scheduler)
+ else:
+ new_parameters.append(param)
+ return new_parameters
def _add_welfare_to_own_batch(
- sample_batch, other_agent_batches, episode, policy, *parameters):
-
- add_utilitarian_w, \
- add_ia_w, ia_alpha, ia_beta, ia_gamma, ia_lambda, \
- add_nash_w, add_egalitarian_w, \
- add_opponent_a, add_opponent_neg_r = parameters
-
- assert len(set(sample_batch[sample_batch.EPS_ID])) == 1, \
- "design to work on one complete episode"
- assert sample_batch[sample_batch.DONES][-1], \
- "design to work on one complete episode, dones: " \
- f"{sample_batch[sample_batch.DONES]}"
+ sample_batch, other_agent_batches, episode, policy, *parameters
+ ):
+
+ (
+ add_utilitarian_w,
+ add_ia_w,
+ ia_alpha,
+ ia_beta,
+ ia_gamma,
+ ia_lambda,
+ add_nash_w,
+ add_egalitarian_w,
+ add_opponent_a,
+ add_opponent_neg_r,
+ ) = parameters
+
+ _assert_working_on_one_full_epi(sample_batch)
if add_utilitarian_w:
- logger.debug(f"add utilitarian welfare to batch of policy"
- f" {policy}")
+ logger.debug(
+ f"add utilitarian welfare to batch of policy" f" {policy}"
+ )
opp_batches = [v[1] for v in other_agent_batches.values()]
sample_batch = _add_utilitarian_welfare_to_batch(
- sample_batch, opp_batches, policy)
+ sample_batch, opp_batches, policy
+ )
if add_ia_w:
- logger.debug(f"add inequity aversion welfare to batch of policy"
- f" {policy}")
+ logger.debug(
+ f"add inequity aversion welfare to batch of policy"
+ f" {policy}"
+ )
_assert_two_players_env(other_agent_batches)
opp_batch = _get_opp_batch(other_agent_batches)
sample_batch = _add_inequity_aversion_welfare_to_batch(
- sample_batch, opp_batch,
+ sample_batch,
+ opp_batch,
alpha=ia_alpha,
beta=ia_beta,
gamma=ia_gamma,
lambda_=ia_lambda,
- policy=policy)
+ policy=policy,
+ )
if add_nash_w:
_assert_two_players_env(other_agent_batches)
opp_batch = _get_opp_batch(other_agent_batches)
sample_batch = _add_nash_welfare_to_batch(
- sample_batch, opp_batch, policy)
+ sample_batch, opp_batch, policy
+ )
if add_egalitarian_w:
_assert_two_players_env(other_agent_batches)
opp_batch = _get_opp_batch(other_agent_batches)
sample_batch = _add_egalitarian_welfare_to_batch(
- sample_batch, opp_batch, policy)
+ sample_batch, opp_batch, policy
+ )
if add_opponent_a:
_assert_two_players_env(other_agent_batches)
opp_batch = _get_opp_batch(other_agent_batches)
sample_batch = _add_opponent_action_to_batch(
- sample_batch, opp_batch, policy)
+ sample_batch, opp_batch, policy
+ )
if add_opponent_neg_r:
_assert_two_players_env(other_agent_batches)
opp_batch = _get_opp_batch(other_agent_batches)
sample_batch = _add_opponent_neg_reward_to_batch(
- sample_batch, opp_batch, policy)
+ sample_batch, opp_batch, policy
+ )
return sample_batch
return postprocess_fn
-def _call_list_of_additional_fn(additional_fn,
- sample_batch, other_agent_batches, episode,
- policy):
+def _assert_working_on_one_full_epi(sample_batch):
+ assert (
+ len(set(sample_batch[sample_batch.EPS_ID])) == 1
+ ), "designed to work on one complete episode"
+ assert (
+ not any(sample_batch[sample_batch.DONES][:-1])
+ or sample_batch[sample_batch.DONES][-1]
+ ), (
+ "welfare postprocessing is designed to work on one complete episode, "
+ f"dones: {sample_batch[sample_batch.DONES]}"
+ )
+
+
+def _call_list_of_additional_fn(
+ additional_fn, sample_batch, other_agent_batches, episode, policy
+):
for postprocessing_function in additional_fn:
sample_batch = postprocessing_function(
- sample_batch, other_agent_batches, episode, policy)
+ sample_batch, other_agent_batches, episode, policy
+ )
return sample_batch
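
As the updated docstring notes, the welfare flags can also be supplied through the policy config so that Tune can search over them. A small sketch (the inequity-aversion numbers are purely illustrative):

from marltoolbox.utils import postprocessing

policy_config_update = {
    postprocessing.ADD_UTILITARIAN_WELFARE: True,
    # Tuple bundling the flag with (alpha, beta, gamma, lambda).
    postprocessing.ADD_INEQUITY_AVERSION_WELFARE: (False, 0.0, 0.5, 0.9, 0.9),
}

# With the values in the config, the generator is called without arguments.
postprocess_fn = postprocessing.welfares_postprocessing_fn()
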
@@ -208,19 +298,20 @@ def _get_opp_batch(other_agent_batches):
def _add_utilitarian_welfare_to_batch(
- sample_batch: SampleBatch,
- opp_ag_batchs: List[SampleBatch],
- policy=None
+ sample_batch: SampleBatch, opp_ag_batchs: List[SampleBatch], policy=None
) -> SampleBatch:
- all_batchs_rewards = ([sample_batch[sample_batch.REWARDS]] +
- [opp_batch[opp_batch.REWARDS] for opp_batch in
- opp_ag_batchs])
- sample_batch.data[WELFARE_UTILITARIAN] = np.array(
- [sum(reward_points) for reward_points in zip(*all_batchs_rewards)])
-
- _ = _log_in_policy(np.sum(sample_batch.data[WELFARE_UTILITARIAN]),
- f"sum_over_epi_{WELFARE_UTILITARIAN}",
- policy)
+ all_batchs_rewards = [sample_batch[sample_batch.REWARDS]] + [
+ opp_batch[opp_batch.REWARDS] for opp_batch in opp_ag_batchs
+ ]
+ sample_batch[WELFARE_UTILITARIAN] = np.array(
+ [sum(reward_points) for reward_points in zip(*all_batchs_rewards)]
+ )
+
+ _ = _log_in_policy(
+ np.sum(sample_batch[WELFARE_UTILITARIAN]),
+ f"sum_over_epi_{WELFARE_UTILITARIAN}",
+ policy,
+ )
return sample_batch
@@ -234,31 +325,40 @@ def _log_in_policy(value, name_value, policy=None):
def _add_opponent_action_to_batch(
- sample_batch: SampleBatch,
- opp_ag_batch: SampleBatch,
- policy=None) -> SampleBatch:
- sample_batch.data[OPPONENT_ACTIONS] = opp_ag_batch[opp_ag_batch.ACTIONS]
- _ = _log_in_policy(sample_batch.data[OPPONENT_ACTIONS][-1],
- f"last_{OPPONENT_ACTIONS}", policy)
+ sample_batch: SampleBatch, opp_ag_batch: SampleBatch, policy=None
+) -> SampleBatch:
+ sample_batch[OPPONENT_ACTIONS] = opp_ag_batch[opp_ag_batch.ACTIONS]
+ _ = _log_in_policy(
+ sample_batch[OPPONENT_ACTIONS][-1],
+ f"last_{OPPONENT_ACTIONS}",
+ policy,
+ )
return sample_batch
def _add_opponent_neg_reward_to_batch(
- sample_batch: SampleBatch,
- opp_ag_batch: SampleBatch,
- policy=None) -> SampleBatch:
- sample_batch.data[OPPONENT_NEGATIVE_REWARD] = np.array(
- [- opp_r for opp_r in opp_ag_batch[opp_ag_batch.REWARDS]])
- _ = _log_in_policy(np.sum(sample_batch.data[OPPONENT_NEGATIVE_REWARD]),
- f"sum_over_epi_{OPPONENT_NEGATIVE_REWARD}", policy)
+ sample_batch: SampleBatch, opp_ag_batch: SampleBatch, policy=None
+) -> SampleBatch:
+ sample_batch[OPPONENT_NEGATIVE_REWARD] = np.array(
+ [-opp_r for opp_r in opp_ag_batch[opp_ag_batch.REWARDS]]
+ )
+ _ = _log_in_policy(
+ np.sum(sample_batch[OPPONENT_NEGATIVE_REWARD]),
+ f"sum_over_epi_{OPPONENT_NEGATIVE_REWARD}",
+ policy,
+ )
return sample_batch
def _add_inequity_aversion_welfare_to_batch(
- sample_batch: SampleBatch, opp_ag_batch: SampleBatch,
- alpha: float, beta: float, gamma: float,
- lambda_: float,
- policy=None) -> SampleBatch:
+ sample_batch: SampleBatch,
+ opp_ag_batch: SampleBatch,
+ alpha: float,
+ beta: float,
+ gamma: float,
+ lambda_: float,
+ policy=None,
+) -> SampleBatch:
"""
:param sample_batch: SampleBatch to mutate
:param opp_ag_batchs:
@@ -274,90 +374,115 @@ def _add_inequity_aversion_welfare_to_batch(
opp_rewards = np.array(opp_ag_batch[opp_ag_batch.REWARDS])
own_rewards = np.flip(own_rewards)
opp_rewards = np.flip(opp_rewards)
- delta = (discount(own_rewards, gamma * lambda_) -
- discount(opp_rewards, gamma * lambda_))
+ delta = discount_cumsum(own_rewards, gamma * lambda_) - discount_cumsum(
+ opp_rewards, gamma * lambda_
+ )
delta = np.flip(delta)
disvalue_lower_than_opp = alpha * (-delta)
disvalue_higher_than_opp = beta * delta
disvalue_lower_than_opp[disvalue_lower_than_opp < 0] = 0
disvalue_higher_than_opp[disvalue_higher_than_opp < 0] = 0
- welfare = sample_batch[sample_batch.REWARDS] - \
- disvalue_lower_than_opp - disvalue_higher_than_opp
+ welfare = (
+ sample_batch[sample_batch.REWARDS]
+ - disvalue_lower_than_opp
+ - disvalue_higher_than_opp
+ )
- sample_batch.data[WELFARE_INEQUITY_AVERSION] = welfare
+ sample_batch[WELFARE_INEQUITY_AVERSION] = welfare
policy = _log_in_policy(
- np.sum(sample_batch.data[WELFARE_INEQUITY_AVERSION]),
- f"sum_over_epi_{WELFARE_INEQUITY_AVERSION}", policy)
+ np.sum(sample_batch[WELFARE_INEQUITY_AVERSION]),
+ f"sum_over_epi_{WELFARE_INEQUITY_AVERSION}",
+ policy,
+ )
return sample_batch
def _add_nash_welfare_to_batch(
- sample_batch: SampleBatch, opp_ag_batch: SampleBatch,
- policy=None) -> SampleBatch:
+ sample_batch: SampleBatch, opp_ag_batch: SampleBatch, policy=None
+) -> SampleBatch:
own_rewards = np.array(opp_ag_batch[opp_ag_batch.REWARDS])
opp_rewards = np.array(opp_ag_batch[opp_ag_batch.REWARDS])
own_rewards_under_defection = np.array(
- opp_ag_batch.data[REWARDS_UNDER_DEFECTION])
+ opp_ag_batch[REWARDS_UNDER_DEFECTION]
+ )
opp_rewards_under_defection = np.array(
- opp_ag_batch.data[REWARDS_UNDER_DEFECTION])
+ opp_ag_batch[REWARDS_UNDER_DEFECTION]
+ )
- own_delta = (sum(own_rewards) - sum(own_rewards_under_defection))
- opp_delta = (sum(opp_rewards) - sum(opp_rewards_under_defection))
+ own_delta = sum(own_rewards) - sum(own_rewards_under_defection)
+ opp_delta = sum(opp_rewards) - sum(opp_rewards_under_defection)
nash_welfare = own_delta * opp_delta
- sample_batch.data[WELFARE_NASH] = ([0.0] * (
- len(sample_batch[sample_batch.REWARDS]) - 1)) + [nash_welfare]
- policy = _log_in_policy(np.sum(sample_batch.data[WELFARE_NASH]),
- f"sum_over_epi_{WELFARE_NASH}", policy)
+ sample_batch[WELFARE_NASH] = (
+ [0.0] * (len(sample_batch[sample_batch.REWARDS]) - 1)
+ ) + [nash_welfare]
+ policy = _log_in_policy(
+ np.sum(sample_batch[WELFARE_NASH]),
+ f"sum_over_epi_{WELFARE_NASH}",
+ policy,
+ )
return sample_batch
def _add_egalitarian_welfare_to_batch(
- sample_batch: SampleBatch, opp_ag_batch: SampleBatch,
- policy=None) -> SampleBatch:
+ sample_batch: SampleBatch, opp_ag_batch: SampleBatch, policy=None
+) -> SampleBatch:
own_rewards = np.array(opp_ag_batch[opp_ag_batch.REWARDS])
opp_rewards = np.array(opp_ag_batch[opp_ag_batch.REWARDS])
own_rewards_under_defection = np.array(
- opp_ag_batch.data[REWARDS_UNDER_DEFECTION])
+ opp_ag_batch[REWARDS_UNDER_DEFECTION]
+ )
opp_rewards_under_defection = np.array(
- opp_ag_batch.data[REWARDS_UNDER_DEFECTION])
+ opp_ag_batch[REWARDS_UNDER_DEFECTION]
+ )
- own_delta = (sum(own_rewards) - sum(own_rewards_under_defection))
- opp_delta = (sum(opp_rewards) - sum(opp_rewards_under_defection))
+ own_delta = sum(own_rewards) - sum(own_rewards_under_defection)
+ opp_delta = sum(opp_rewards) - sum(opp_rewards_under_defection)
egalitarian_welfare = min(own_delta, opp_delta)
- sample_batch.data[WELFARE_EGALITARIAN] = ([0.0] * (
- len(sample_batch[sample_batch.REWARDS]) - 1)) + [
- egalitarian_welfare]
- policy = _log_in_policy(np.sum(sample_batch.data[WELFARE_EGALITARIAN]),
- f"sum_over_epi_{WELFARE_EGALITARIAN}",
- policy)
+ sample_batch[WELFARE_EGALITARIAN] = (
+ [0.0] * (len(sample_batch[sample_batch.REWARDS]) - 1)
+ ) + [egalitarian_welfare]
+ policy = _log_in_policy(
+ np.sum(sample_batch[WELFARE_EGALITARIAN]),
+ f"sum_over_epi_{WELFARE_EGALITARIAN}",
+ policy,
+ )
return sample_batch
class OverwriteRewardWtWelfareCallback(DefaultCallbacks):
-
def on_postprocess_trajectory(
- self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
- agent_id: AgentID, policy_id: PolicyID,
- policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
- original_batches: Dict[AgentID, SampleBatch], **kwargs):
-
- assert sum([k in WELFARES for k in
- postprocessed_batch.data.keys()]) <= 1, \
- "only one welfare must be available"
+ self,
+ *,
+ worker: "RolloutWorker",
+ episode: MultiAgentEpisode,
+ agent_id: AgentID,
+ policy_id: PolicyID,
+ policies: Dict[PolicyID, Policy],
+ postprocessed_batch: SampleBatch,
+ original_batches: Dict[AgentID, SampleBatch],
+ **kwargs,
+ ):
+
+ assert (
+ sum([k in WELFARES for k in postprocessed_batch.keys()]) <= 1
+ ), "only one welfare must be available"
for welfare_key in WELFARES:
- if welfare_key in postprocessed_batch.data.keys():
- postprocessed_batch[postprocessed_batch.REWARDS] = \
- postprocessed_batch.data[welfare_key]
- msg = f"Overwrite the reward of agent_id {agent_id} " \
- f"with the value from the" \
- f" welfare_key {welfare_key}"
+ if welfare_key in postprocessed_batch.keys():
+ postprocessed_batch[
+ postprocessed_batch.REWARDS
+ ] = postprocessed_batch[welfare_key]
+ msg = (
+ f"Overwrite the reward of agent_id {agent_id} "
+ f"with the value from the"
+ f" welfare_key {welfare_key}"
+ )
print(msg)
logger.debug(msg)
break
@@ -366,7 +491,8 @@ def on_postprocess_trajectory(
def apply_preprocessors(worker, raw_observation, policy_id):
- prep_obs = _get_or_raise(
- worker.preprocessors, policy_id).transform(raw_observation)
+ prep_obs = _get_or_raise(worker.preprocessors, policy_id).transform(
+ raw_observation
+ )
filtered_obs = _get_or_raise(worker.filters, policy_id)(prep_obs)
return filtered_obs
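
To make the switch from discount to discount_cumsum easy to check, here is a self-contained numpy sketch of the delta computation in _add_inequity_aversion_welfare_to_batch; a local helper stands in for RLlib's discount_cumsum, and the rewards/hyperparameters are illustrative:

import numpy as np


def discount_cumsum(x, gamma):
    # out[t] = x[t] + gamma * x[t + 1] + gamma**2 * x[t + 2] + ...
    out = np.zeros_like(x, dtype=float)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        out[t] = running
    return out


# Toy per-step rewards for both players over one episode (illustrative only).
own_r = np.array([1.0, 0.0, 1.0])
opp_r = np.array([0.0, 1.0, 0.0])
alpha, beta, gamma, lambda_ = 0.5, 0.05, 0.96, 0.96

# Flipping before/after the cumulative sum makes delta[t] a discounted sum
# over *past* reward differences, as in the function above.
delta = np.flip(
    discount_cumsum(np.flip(own_r), gamma * lambda_)
    - discount_cumsum(np.flip(opp_r), gamma * lambda_)
)
disvalue_lower_than_opp = np.clip(alpha * (-delta), 0.0, None)
disvalue_higher_than_opp = np.clip(beta * delta, 0.0, None)
welfare = own_r - disvalue_lower_than_opp - disvalue_higher_than_opp
print(welfare)
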
diff --git a/marltoolbox/utils/restore.py b/marltoolbox/utils/restore.py
index 519925f..75e4a3b 100644
--- a/marltoolbox/utils/restore.py
+++ b/marltoolbox/utils/restore.py
@@ -1,13 +1,18 @@
import logging
import os
import pickle
+from typing import List
+
+from marltoolbox import utils
+from marltoolbox.utils import path
+from ray.tune.analysis import ExperimentAnalysis
logger = logging.getLogger(__name__)
LOAD_FROM_CONFIG_KEY = "checkpoint_to_load_from"
-def after_init_load_policy_checkpoint(
+def before_loss_init_load_policy_checkpoint(
policy, observation_space=None, action_space=None, trainer_config=None
):
"""
@@ -27,7 +32,7 @@ def after_init_load_policy_checkpoint(
Example: determining the checkpoint to load conditional on the current seed
(when doing a grid_search over random seeds and with a multistage training)
"""
- checkpoint_path, policy_id = policy.config.pop(
+ checkpoint_path, policy_id = policy.config.get(
LOAD_FROM_CONFIG_KEY, (None, None)
)
@@ -40,8 +45,7 @@ def after_init_load_policy_checkpoint(
f"marltoolbox restore: checkpoint found for policy_id: "
f"{policy_id}"
)
- logger.info(msg)
- print(msg)
+ logger.debug(msg)
else:
msg = (
f"marltoolbox restore: NO checkpoint found for policy_id:"
@@ -49,7 +53,6 @@ def after_init_load_policy_checkpoint(
f"Not found under the config key: {LOAD_FROM_CONFIG_KEY}"
)
logger.warning(msg)
- print(msg)
def load_one_policy_checkpoint(
@@ -75,7 +78,7 @@ def load_one_policy_checkpoint(
policy.load_checkpoint(checkpoint_tuple=(checkpoint_path, policy_id))
else:
checkpoint_path = os.path.expanduser(checkpoint_path)
- logger.info(f"checkpoint_path {checkpoint_path}")
+ logger.debug(f"checkpoint_path {checkpoint_path}")
checkpoint = pickle.load(open(checkpoint_path, "rb"))
assert "worker" in checkpoint.keys()
assert "optimizer" not in checkpoint.keys()
@@ -86,7 +89,7 @@ def load_one_policy_checkpoint(
found_policy_id = False
for p_id, state in objs["state"].items():
if p_id == policy_id:
- print(
+ logger.debug(
f"going to load policy {policy_id} "
f"from checkpoint {checkpoint_path}"
)
@@ -94,8 +97,114 @@ def load_one_policy_checkpoint(
found_policy_id = True
break
if not found_policy_id:
- print(
+ logger.debug(
f"policy_id {policy_id} not in "
f'checkpoint["worker"]["state"].keys() '
f'{objs["state"].keys()}'
)
+
+
+def extract_checkpoints_from_tune_analysis(
+ tune_experiment_analysis: ExperimentAnalysis,
+) -> List[str]:
+ """
+    Extract the best checkpoint of each trial from a tune analysis object.
+    The tune analysis can contain several trials. Each trial can contain
+    several checkpoints; only the best checkpoint per trial is returned.
+
+    :param tune_experiment_analysis:
+    :return: list of the unique best checkpoints, one per trial in the
+        tune analysis.
+ """
+ logger.info("start extract_checkpoints")
+
+ for trial in tune_experiment_analysis.trials:
+ checkpoints = tune_experiment_analysis.get_trial_checkpoints_paths(
+ trial, tune_experiment_analysis.default_metric
+ )
+ assert len(checkpoints) > 0
+
+ all_best_checkpoints_per_trial = [
+ tune_experiment_analysis.get_best_checkpoint(
+ trial,
+ metric=tune_experiment_analysis.default_metric,
+ mode=tune_experiment_analysis.default_mode,
+ )
+ for trial in tune_experiment_analysis.trials
+ ]
+
+ for checkpoint in all_best_checkpoints_per_trial:
+ assert checkpoint is not None
+
+ logger.info("end extract_checkpoints")
+ return all_best_checkpoints_per_trial
+
+
+def get_checkpoint_for_each_replicates(
+ all_replicates_save_dir: List[str],
+) -> List[str]:
+ """
+    Get the list of paths to the checkpoint files inside an experiment dir of
+    RLLib/Tune (which can contain several trials).
+    Works for an experiment whose trials each contain a unique checkpoint.
+
+    :param all_replicates_save_dir: list of trial dirs
+ :return: list of paths to checkpoint files
+ """
+ ckpt_dir_per_replicate = []
+ for replicate_dir_path in all_replicates_save_dir:
+ ckpt_dir_path = get_ckpt_dir_for_one_replicate(replicate_dir_path)
+ ckpt_path = get_ckpt_from_ckpt_dir(ckpt_dir_path)
+ ckpt_dir_per_replicate.append(ckpt_path)
+ return ckpt_dir_per_replicate
+
+
+def get_ckpt_dir_for_one_replicate(replicate_dir_path: str) -> str:
+ """
+ Get the path to the unique checkpoint dir inside a trial dir of RLLib/Tune.
+
+ :param replicate_dir_path: trial dir
+ :return: path to checkpoint dir
+ """
+    partially_filtered_ckpt_dir = (
+ utils.path.get_children_paths_wt_selecting_filter(
+ replicate_dir_path, _filter="checkpoint_"
+ )
+ )
+ ckpt_dir = [
+ file_path
+        for file_path in partially_filtered_ckpt_dir
+ if ".is_checkpoint" not in file_path
+ ]
+ assert len(ckpt_dir) == 1, f"{ckpt_dir}"
+ return ckpt_dir[0]
+
+
+def get_ckpt_from_ckpt_dir(ckpt_dir_path: str) -> str:
+ """
+ Get the path to the unique checkpoint file inside a checkpoint dir of
+ RLLib/Tune
+ :param ckpt_dir_path: checkpoint dir
+ :return: path to checkpoint file
+ """
+    partially_filtered_ckpt_path = (
+ utils.path.get_children_paths_wt_discarding_filter(
+ ckpt_dir_path, _filter="tune_metadata"
+ )
+ )
+ filters = [
+ # For Tune/RLLib
+ ".is_checkpoint",
+ # For TensorFlow
+ "ckpt.index",
+ "ckpt.data-",
+ "ckpt.meta",
+ ".json",
+ ]
+ ckpt_path = filter(
+ lambda el: all(filter_ not in el for filter_ in filters),
+        partially_filtered_ckpt_path,
+ )
+ ckpt_path = list(ckpt_path)
+ assert len(ckpt_path) == 1, f"{ckpt_path}"
+ return ckpt_path[0]
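
A hypothetical end-to-end sketch of the new checkpoint extraction: run a short Tune experiment with one checkpoint per trial, then collect the best checkpoint path of each trial:

from ray import tune

from marltoolbox.utils import restore

# Hypothetical experiment: two trials, one checkpoint saved at the end of each.
tune_analysis = tune.run(
    "PG",
    config={"env": "CartPole-v0", "framework": "torch"},
    stop={"training_iteration": 2},
    num_samples=2,
    checkpoint_at_end=True,
    metric="episode_reward_mean",
    mode="max",
)

# One best checkpoint path per trial (here each trial only has one checkpoint).
checkpoint_paths = restore.extract_checkpoints_from_tune_analysis(
    tune_analysis
)
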
diff --git a/marltoolbox/utils/rollout.py b/marltoolbox/utils/rollout.py
index 84be755..6c8708e 100644
--- a/marltoolbox/utils/rollout.py
+++ b/marltoolbox/utils/rollout.py
@@ -5,18 +5,24 @@
import collections
import copy
+import logging
from typing import List
from gym import wrappers as gym_wrappers
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.rollout import DefaultMapping, default_policy_agent_mapping, \
- RolloutSaver
+from ray.rllib.rollout import (
+ DefaultMapping,
+ default_policy_agent_mapping,
+ RolloutSaver,
+)
from ray.rllib.utils.framework import TensorStructType
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.rllib.utils.typing import EnvInfoDict, PolicyID
+logger = logging.getLogger(__name__)
+
class RolloutManager(RolloutSaver):
"""
@@ -43,26 +49,29 @@ def append_step(self, obs, action, next_obs, reward, done, info):
"""Add a step to the current rollout, if we are saving them"""
if self._save_info:
self._current_rollout.append(
- [obs, action, next_obs, reward, done, info])
+ [obs, action, next_obs, reward, done, info]
+ )
else:
- self._current_rollout.append(
- [obs, action, next_obs, reward, done])
+ self._current_rollout.append([obs, action, next_obs, reward, done])
self._total_steps += 1
-def internal_rollout(worker,
- num_steps,
- policy_map=None,
- policy_agent_mapping=None,
- reset_env_before=True,
- num_episodes=0,
- last_obs=None,
- saver=None,
- no_render=True,
- video_dir=None,
- seed=None,
- explore=None,
- ):
+def internal_rollout(
+ worker,
+ num_steps,
+ policy_map=None,
+ policy_agent_mapping=None,
+ reset_env_before=True,
+ num_episodes=0,
+ last_obs=None,
+ saver=None,
+ no_render=True,
+ video_dir=None,
+ seed=None,
+ explore=None,
+ last_rnn_states=None,
+ base_env=None,
+):
"""
Can perform rollouts on the environment from inside a worker_rollout or
from a policy. Can perform rollouts during the evaluation rollouts ran
@@ -85,6 +94,7 @@ def internal_rollout(worker,
:param video_dir: (optional)
:param seed: (optional) random seed to set for the environment by calling
env.seed(seed)
+ :param last_rnn_states: map of policy_id to rnn_states
:return: an instance of a RolloutManager, which contains the data about
the rollouts performed
"""
@@ -95,15 +105,21 @@ def internal_rollout(worker,
if saver is None:
saver = RolloutManager()
- env = copy.deepcopy(worker.env)
- if hasattr(env, "seed") and callable(env.seed):
- env.seed(seed)
+ if base_env is None:
+ env = worker.env
+ else:
+ env = base_env.get_unwrapped()[0]
+
+ # if hasattr(env, "seed") and callable(env.seed):
+ # env.seed(seed)
+ env = copy.deepcopy(env)
multiagent = isinstance(env, MultiAgentEnv)
if policy_agent_mapping is None:
if worker.multiagent:
policy_agent_mapping = worker.policy_config["multiagent"][
- "policy_mapping_fn"]
+ "policy_mapping_fn"
+ ]
else:
policy_agent_mapping = default_policy_agent_mapping
@@ -123,40 +139,54 @@ def internal_rollout(worker,
env=env,
directory=video_dir,
video_callable=lambda x: True,
- force=True)
+ force=True,
+ )
random_policy_id = list(policy_map.keys())[0]
virtual_global_timestep = worker.get_policy(
- random_policy_id).global_timestep
+ random_policy_id
+ ).global_timestep
steps = 0
episodes = 0
while _keep_going(steps, num_steps, episodes, num_episodes):
+ # logger.info(f"Starting epsiode {episodes} in rollout")
+ # print(f"Starting epsiode {episodes} in rollout")
mapping_cache = {} # in case policy_agent_mapping is stochastic
saver.begin_rollout()
- if reset_env_before or episodes > 0:
- obs = env.reset()
- else:
- obs = last_obs
- agent_states = DefaultMapping(
- lambda agent_id_: state_init[mapping_cache[agent_id_]])
+ obs, agent_states = _get_first_obs(
+ env,
+ reset_env_before,
+ episodes,
+ last_obs,
+ mapping_cache,
+ state_init,
+ last_rnn_states,
+ )
prev_actions = DefaultMapping(
- lambda agent_id_: action_init[mapping_cache[agent_id_]])
- prev_rewards = collections.defaultdict(lambda: 0.)
+ lambda agent_id_: action_init[mapping_cache[agent_id_]]
+ )
+ prev_rewards = collections.defaultdict(lambda: 0.0)
done = False
reward_total = 0.0
- while not done and _keep_going(steps, num_steps, episodes,
- num_episodes):
-
+ while not done and _keep_going(
+ steps, num_steps, episodes, num_episodes
+ ):
multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
action_dict = {}
virtual_global_timestep += 1
for agent_id, a_obs in multi_obs.items():
if a_obs is not None:
policy_id = mapping_cache.setdefault(
- agent_id, policy_agent_mapping(agent_id))
+ agent_id, policy_agent_mapping(agent_id)
+ )
p_use_lstm = use_lstm[policy_id]
- # print("rollout")
+ # print("p_use_lstm", p_use_lstm)
+ # print(
+ # agent_id,
+ # "agent_states[agent_id]",
+ # agent_states[agent_id],
+ # )
if p_use_lstm:
a_action, p_state, _ = _worker_compute_action(
worker,
@@ -166,8 +196,12 @@ def internal_rollout(worker,
prev_action=prev_actions[agent_id],
prev_reward=prev_rewards[agent_id],
policy_id=policy_id,
- explore=explore
+ explore=explore,
)
+ # print(
+ # "after rollout _worker_compute_action p_state",
+ # p_state,
+ # )
agent_states[agent_id] = p_state
else:
a_action = _worker_compute_action(
@@ -177,7 +211,7 @@ def internal_rollout(worker,
prev_action=prev_actions[agent_id],
prev_reward=prev_rewards[agent_id],
policy_id=policy_id,
- explore=explore
+ explore=explore,
)
a_action = flatten_to_single_ndarray(a_action)
action_dict[agent_id] = a_action
@@ -196,7 +230,8 @@ def internal_rollout(worker,
if multiagent:
done = done["__all__"]
reward_total += sum(
- r for r in reward.values() if r is not None)
+ r for r in reward.values() if r is not None
+ )
else:
reward_total += reward
if not no_render:
@@ -228,24 +263,55 @@ def _keep_going(steps, num_steps, episodes, num_episodes):
return True
-def _worker_compute_action(worker, timestep,
- observation: TensorStructType,
- state: List[TensorStructType] = None,
- prev_action: TensorStructType = None,
- prev_reward: float = None,
- info: EnvInfoDict = None,
- policy_id: PolicyID = DEFAULT_POLICY_ID,
- full_fetch: bool = False,
- explore: bool = None) -> TensorStructType:
+def _get_first_obs(
+ env,
+ reset_env_before,
+ episodes,
+ last_obs,
+ mapping_cache,
+ state_init,
+ last_rnn_states,
+):
+ if reset_env_before or episodes > 0:
+ obs = env.reset()
+ agent_states = DefaultMapping(
+ lambda agent_id_: state_init[mapping_cache[agent_id_]]
+ )
+ else:
+ obs = last_obs
+ if last_rnn_states is not None:
+ agent_states = DefaultMapping(
+ lambda agent_id_: last_rnn_states[mapping_cache[agent_id_]]
+ )
+ else:
+ agent_states = DefaultMapping(
+ lambda agent_id_: state_init[mapping_cache[agent_id_]]
+ )
+ return obs, agent_states
+
+
+def _worker_compute_action(
+ worker,
+ timestep,
+ observation: TensorStructType,
+ state: List[TensorStructType] = None,
+ prev_action: TensorStructType = None,
+ prev_reward: float = None,
+ info: EnvInfoDict = None,
+ policy_id: PolicyID = DEFAULT_POLICY_ID,
+ full_fetch: bool = False,
+ explore: bool = None,
+) -> TensorStructType:
"""
Modified version of the Trainer compute_action method
"""
if state is None:
state = []
- preprocessed = worker.preprocessors[
- policy_id].transform(observation)
- filtered_obs = worker.filters[policy_id](
- preprocessed, update=False)
+ # Check the preprocessor and preprocess, if necessary.
+ pp = worker.preprocessors[policy_id]
+ if type(pp).__name__ != "NoPreprocessor":
+ observation = pp.transform(observation)
+ filtered_obs = worker.filters[policy_id](observation, update=False)
result = worker.get_policy(policy_id).compute_single_action(
filtered_obs,
state,
@@ -254,7 +320,8 @@ def _worker_compute_action(worker, timestep,
info,
clip_actions=worker.policy_config["clip_actions"],
explore=explore,
- timestep=timestep)
+ timestep=timestep,
+ )
if state or full_fetch:
return result
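
A small hedged sketch of calling the extended internal_rollout from an existing RLlib RolloutWorker (e.g. trainer.workers.local_worker()); only arguments from the signature above are used:

from marltoolbox.utils.rollout import internal_rollout


def run_short_internal_rollout(worker, n_steps=20):
    """Roll out one episode (at most n_steps steps) and return the
    RolloutManager holding the collected transitions."""
    return internal_rollout(
        worker,
        num_steps=n_steps,
        num_episodes=1,
        explore=False,
    )
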
diff --git a/marltoolbox/utils/tune_analysis.py b/marltoolbox/utils/tune_analysis.py
new file mode 100644
index 0000000..82c283e
--- /dev/null
+++ b/marltoolbox/utils/tune_analysis.py
@@ -0,0 +1,144 @@
+from marltoolbox.utils.miscellaneous import move_to_key, logger
+
+
+def extract_value_from_last_training_iteration_for_each_trials(
+ tune_analysis,
+ metric="episode_reward_mean",
+):
+ metric_values = []
+ for trial in tune_analysis.trials:
+ last_results = trial.last_result
+ _, _, value, found = move_to_key(last_results, key=metric)
+ assert (
+ found
+ ), f"metric: {metric} not found in last_results: {last_results}"
+ metric_values.append(value)
+ return metric_values
+
+
+def extract_metrics_for_each_trials(
+ tune_analysis,
+ metric="episode_reward_mean",
+ metric_mode="avg",
+):
+ metric_values = []
+ for trial in tune_analysis.trials:
+ metric_values.append(trial.metric_analysis[metric][metric_mode])
+ return metric_values
+
+
+def check_learning_achieved(
+ tune_results,
+ metric="episode_reward_mean",
+ trial_idx=0,
+ max_: float = None,
+ min_: float = None,
+ equal_: float = None,
+):
+ assert max_ is not None or min_ is not None or equal_ is not None
+
+ last_results = tune_results.trials[trial_idx].last_result
+ _, _, value, found = move_to_key(last_results, key=metric)
+ assert (
+ found
+ ), f"metric {metric} not found inside last_results {last_results}"
+
+ msg = (
+ f"Trial {trial_idx} achieved "
+ f"{value}"
+ f" on metric {metric}. This is a success if the value is below"
+ f" {max_} or above {min_} or equal to {equal_}."
+ )
+
+ logger.info(msg)
+ print(msg)
+ if min_ is not None:
+ assert value >= min_, f"value {value} must be above min_ {min_}"
+ if max_ is not None:
+ assert value <= max_, f"value {value} must be below max_ {max_}"
+ if equal_ is not None:
+ assert value == equal_, (
+ f"value {value} must be equal to equal_ " f"{equal_}"
+ )
+
+
+def extract_config_values_from_tune_analysis(tune_experiment_analysis, key):
+ values = []
+ for trial in tune_experiment_analysis.trials:
+ dict_, k, current_value, found = move_to_key(trial.config, key)
+ if found:
+ values.append(current_value)
+ else:
+ values.append(None)
+ return values
+
+
+ABOVE = "above"
+EQUAL = "equal"
+BELOW = "below"
+FILTERING_MODES = (ABOVE, EQUAL, BELOW)
+
+RLLIB_METRICS_MODES = (
+ "avg",
+ "min",
+ "max",
+ "last",
+ "last-5-avg",
+ "last-10-avg",
+)
+
+
+def filter_trials(
+    experiment_analysis,
+ metric,
+ metric_threshold: float,
+ metric_mode="last-5-avg",
+ threshold_mode=ABOVE,
+):
+ """
+ Filter trials of an ExperimentAnalysis
+
+    :param experiment_analysis:
+ :param metric:
+ :param metric_threshold:
+ :param metric_mode:
+ :param threshold_mode:
+ :return:
+ """
+ assert threshold_mode in FILTERING_MODES, (
+ f"threshold_mode {threshold_mode} " f"must be in {FILTERING_MODES}"
+ )
+ assert metric_mode in RLLIB_METRICS_MODES
+ print(
+ "Before trial filtering:", len(experiement_analysis.trials), "trials"
+ )
+ trials_filtered = []
+ print(
+ "metric_threshold", metric_threshold, "threshold_mode", threshold_mode
+ )
+    for trial_idx, trial in enumerate(experiment_analysis.trials):
+ available_metrics = trial.metric_analysis
+ try:
+ metric_value = available_metrics[metric][metric_mode]
+ except KeyError:
+ raise KeyError(
+ f"failed to read metric key:{metric} in "
+ f"available_metrics:{available_metrics}"
+ )
+ print(
+ f"trial_idx {trial_idx} "
+ f"available_metrics[{metric}][{metric_mode}] "
+ f"{metric_value}"
+ )
+ if threshold_mode == ABOVE and metric_value > metric_threshold:
+ trials_filtered.append(trial)
+ elif threshold_mode == EQUAL and metric_value == metric_threshold:
+ trials_filtered.append(trial)
+ elif threshold_mode == BELOW and metric_value < metric_threshold:
+ trials_filtered.append(trial)
+ else:
+ print(f"filtering out trial {trial_idx}")
+
+    experiment_analysis.trials = trials_filtered
+    print("After trial filtering:", len(experiment_analysis.trials), "trials")
+    return experiment_analysis
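
Typical usage of the new trial filtering, assuming tune_analysis is the ExperimentAnalysis returned by a tune.run(...) call; the metric and threshold are illustrative:

from marltoolbox.utils import tune_analysis as ta

# Keep only the trials whose last-5-iterations average reward is above 0.0 ...
filtered_analysis = ta.filter_trials(
    tune_analysis,
    metric="episode_reward_mean",
    metric_threshold=0.0,
    metric_mode="last-5-avg",
    threshold_mode=ta.ABOVE,
)
# ... then read one config value per remaining trial.
seeds = ta.extract_config_values_from_tune_analysis(filtered_analysis, "seed")
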
diff --git a/setup.py b/setup.py
index 84d7b01..ae23115 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@ def read(fname):
long_description=read("README.md"),
license="MIT",
install_requires=[
- "ray[rllib]==1.0.0",
+ "ray[rllib]>=1.2.0",
"gym==0.17.3",
"torch>=1.6.0,<=1.7.0",
"tensorboard==1.15.0",
diff --git a/tests/marltoolbox/algos/amTFT/test_amTFTRolloutsTorchPolicy.py b/tests/marltoolbox/algos/amTFT/test_amTFTRolloutsTorchPolicy.py
index ac1a3d4..7047c0e 100644
--- a/tests/marltoolbox/algos/amTFT/test_amTFTRolloutsTorchPolicy.py
+++ b/tests/marltoolbox/algos/amTFT/test_amTFTRolloutsTorchPolicy.py
@@ -18,10 +18,7 @@
from marltoolbox.envs.matrix_sequential_social_dilemma import (
IteratedPrisonersDilemma,
)
-from marltoolbox.experiments.rllib_api.amtft_various_env import (
- get_rllib_config,
- get_hyperparameters,
-)
+from marltoolbox.experiments.rllib_api import amtft_various_env
from marltoolbox.utils import postprocessing, log
from test_base_policy import init_amTFT, generate_fake_discrete_actions
@@ -46,14 +43,22 @@ def test_compute_actions_overwrite():
(fake_actions, fake_state_out, fake_extra_fetches),
(fake_actions_2nd, fake_state_out_2nd, fake_extra_fetches_2nd),
]
- actions, state_out, extra_fetches = am_tft_policy.compute_actions(
- observations[env.players_ids[0]]
+ actions, state_out, extra_fetches = am_tft_policy._compute_action_helper(
+ observations[env.players_ids[0]],
+ state_batches=None,
+ seq_lens=1,
+ explore=True,
+ timestep=0,
)
assert actions == fake_actions
assert state_out == fake_state_out
assert extra_fetches == fake_extra_fetches
- actions, state_out, extra_fetches = am_tft_policy.compute_actions(
- observations[env.players_ids[0]]
+ actions, state_out, extra_fetches = am_tft_policy._compute_action_helper(
+ observations[env.players_ids[0]],
+ state_batches=None,
+ seq_lens=1,
+ explore=True,
+ timestep=0,
)
assert actions == fake_actions_2nd
assert state_out == fake_state_out_2nd
@@ -65,28 +70,38 @@ def test__select_algo_to_use_in_eval():
policy_class=amTFT.AmTFTRolloutsTorchPolicy
)
- def assert_(working_state_idx, active_algo_idx):
+ def assert_active_algo_idx(working_state_idx, active_algo_idx):
am_tft_policy.working_state = base_policy.WORKING_STATES[
working_state_idx
]
- am_tft_policy._select_witch_algo_to_use()
- assert am_tft_policy.active_algo_idx == active_algo_idx
+ am_tft_policy._select_witch_algo_to_use(None)
+ assert (
+ am_tft_policy.active_algo_idx == active_algo_idx
+ ), f"{am_tft_policy.active_algo_idx} == {active_algo_idx}"
am_tft_policy.use_opponent_policies = False
am_tft_policy.n_steps_to_punish = 0
- assert_(working_state_idx=2, active_algo_idx=base.OWN_COOP_POLICY_IDX)
+ assert_active_algo_idx(
+ working_state_idx=2, active_algo_idx=base.OWN_COOP_POLICY_IDX
+ )
am_tft_policy.use_opponent_policies = False
am_tft_policy.n_steps_to_punish = 1
- assert_(working_state_idx=2, active_algo_idx=base.OWN_SELFISH_POLICY_IDX)
+ assert_active_algo_idx(
+ working_state_idx=2, active_algo_idx=base.OWN_SELFISH_POLICY_IDX
+ )
am_tft_policy.use_opponent_policies = True
am_tft_policy.performing_rollouts = True
am_tft_policy.n_steps_to_punish_opponent = 0
- assert_(working_state_idx=2, active_algo_idx=base.OPP_COOP_POLICY_IDX)
+ assert_active_algo_idx(
+ working_state_idx=2, active_algo_idx=base.OPP_COOP_POLICY_IDX
+ )
am_tft_policy.use_opponent_policies = True
am_tft_policy.performing_rollouts = True
am_tft_policy.n_steps_to_punish_opponent = 1
- assert_(working_state_idx=2, active_algo_idx=base.OPP_SELFISH_POLICY_IDX)
+ assert_active_algo_idx(
+ working_state_idx=2, active_algo_idx=base.OPP_SELFISH_POLICY_IDX
+ )
def test__duration_found_or_continue_search():
@@ -156,12 +171,22 @@ def step(self, actions: dict):
return observations, rewards, epi_is_done, info
-def make_FakePolicyWtDefinedActions(list_actions_to_play, ParentPolicyCLass):
+def make_fake_policy_class_wt_defined_actions(
+ list_actions_to_play, ParentPolicyCLass
+):
class FakePolicyWtDefinedActions(ParentPolicyCLass):
- def compute_actions(self, *args, **kwargs):
+ def _compute_action_helper(self, *args, **kwargs):
+ print("len", len(list_actions_to_play))
action = list_actions_to_play.pop(0)
return np.array([action]), [], {}
+ def _initialize_loss_from_dummy_batch(
+ self,
+ auto_remove_unneeded_view_reqs: bool = True,
+ stats_fn=None,
+ ) -> None:
+ pass
+
return FakePolicyWtDefinedActions
@@ -178,20 +203,23 @@ def init_worker(
debug = True
exp_name, _ = log.log_in_current_day_dir("testing")
- hparams = get_hyperparameters(
+ hparams = amtft_various_env.get_hyperparameters(
debug,
train_n_replicates,
filter_utilitarian=False,
env="IteratedPrisonersDilemma",
)
- _, _, rllib_config = get_rllib_config(
+ stop, env_config, rllib_config = amtft_various_env.get_rllib_config(
hparams, welfare_fn=postprocessing.WELFARE_UTILITARIAN
)
rllib_config["env"] = FakeEnvWtActionAsReward
rllib_config["env_config"]["max_steps"] = max_steps
- rllib_config["seed"] = int(time.time())
+ rllib_config = _remove_dynamic_values_from_config(
+ rllib_config, hparams, env_config, stop
+ )
+
for policy_id in FakeEnvWtActionAsReward({}).players_ids:
policy_to_modify = list(
rllib_config["multiagent"]["policies"][policy_id]
@@ -202,31 +230,36 @@ def init_worker(
if actions_list_0 is not None:
policy_to_modify[3]["nested_policies"][0][
"Policy_class"
- ] = make_FakePolicyWtDefinedActions(
+ ] = make_fake_policy_class_wt_defined_actions(
copy.deepcopy(actions_list_0), DEFAULT_NESTED_POLICY_COOP
)
if actions_list_1 is not None:
policy_to_modify[3]["nested_policies"][1][
"Policy_class"
- ] = make_FakePolicyWtDefinedActions(
+ ] = make_fake_policy_class_wt_defined_actions(
copy.deepcopy(actions_list_1), DEFAULT_NESTED_POLICY_SELFISH
)
if actions_list_2 is not None:
policy_to_modify[3]["nested_policies"][2][
"Policy_class"
- ] = make_FakePolicyWtDefinedActions(
+ ] = make_fake_policy_class_wt_defined_actions(
copy.deepcopy(actions_list_2), DEFAULT_NESTED_POLICY_COOP
)
if actions_list_3 is not None:
policy_to_modify[3]["nested_policies"][3][
"Policy_class"
- ] = make_FakePolicyWtDefinedActions(
+ ] = make_fake_policy_class_wt_defined_actions(
copy.deepcopy(actions_list_3), DEFAULT_NESTED_POLICY_SELFISH
)
rllib_config["multiagent"]["policies"][policy_id] = tuple(
policy_to_modify
)
+ rllib_config["exploration_config"]["temperature_schedule"] = rllib_config[
+ "exploration_config"
+ ]["temperature_schedule"].func(rllib_config)
dqn_trainer = DQNTrainer(
rllib_config, logger_creator=_get_logger_creator(exp_name)
)
@@ -236,10 +269,35 @@ def init_worker(
am_tft_policy_col = worker.get_policy("player_col")
am_tft_policy_row.working_state = WORKING_STATES[2]
am_tft_policy_col.working_state = WORKING_STATES[2]
+ print("env setup")
return worker, am_tft_policy_row, am_tft_policy_col
+def _remove_dynamic_values_from_config(
+ rllib_config, hparams, env_config, stop
+):
+ rllib_config["seed"] = int(time.time())
+ rllib_config["learning_starts"] = int(
+ rllib_config["env_config"]["max_steps"]
+ * rllib_config["env_config"]["bs_epi_mul"]
+ )
+ rllib_config["buffer_size"] = int(
+ env_config["max_steps"]
+ * env_config["buf_frac"]
+ * stop["episodes_total"]
+ )
+ rllib_config["train_batch_size"] = int(
+ env_config["max_steps"] * env_config["bs_epi_mul"]
+ )
+ rllib_config["training_intensity"] = int(
+ rllib_config["num_envs_per_worker"]
+ * rllib_config["num_workers"]
+ * hparams["training_intensity"]
+ )
+ return rllib_config
+
+
def _get_logger_creator(exp_name):
logdir_prefix = exp_name + "/"
tail, head = os.path.split(exp_name)
@@ -266,7 +324,9 @@ def default_logger_creator(config):
def test__compute_debit_using_rollouts():
- def assert_(worker_, am_tft_policy, last_obs, opp_action, assert_debit):
+ def assert_debit_value_computed(
+ worker_, am_tft_policy, last_obs, opp_action, assert_debit
+ ):
worker_.foreach_env(lambda env: env.reset())
debit = am_tft_policy._compute_debit_using_rollouts(
last_obs, opp_action, worker_
@@ -291,14 +351,14 @@ def init_no_extra_reward(max_steps_):
worker, am_tft_policy_row, am_tft_policy_col = init_no_extra_reward(
max_steps
)
- assert_(
+ assert_debit_value_computed(
worker,
am_tft_policy_row,
{"player_row": 0, "player_col": 0},
opp_action=0,
assert_debit=0,
)
- assert_(
+ assert_debit_value_computed(
worker,
am_tft_policy_col,
{"player_row": 1, "player_col": 0},
@@ -309,14 +369,14 @@ def init_no_extra_reward(max_steps_):
worker, am_tft_policy_row, am_tft_policy_col = init_no_extra_reward(
max_steps
)
- assert_(
+ assert_debit_value_computed(
worker,
am_tft_policy_row,
{"player_row": 1, "player_col": 0},
opp_action=1,
assert_debit=1,
)
- assert_(
+ assert_debit_value_computed(
worker,
am_tft_policy_col,
{"player_row": 1, "player_col": 1},
@@ -342,14 +402,14 @@ def init_selfish_opp_advantaged(max_steps):
worker, am_tft_policy_row, am_tft_policy_col = init_selfish_opp_advantaged(
max_steps
)
- assert_(
+ assert_debit_value_computed(
worker,
am_tft_policy_row,
{"player_row": 0, "player_col": 0},
opp_action=0,
assert_debit=0,
)
- assert_(
+ assert_debit_value_computed(
worker,
am_tft_policy_col,
{"player_row": 1, "player_col": 0},
@@ -375,14 +435,14 @@ def init_coop_opp_advantaged(max_steps):
worker, am_tft_policy_row, am_tft_policy_col = init_coop_opp_advantaged(
max_steps
)
- assert_(
+ assert_debit_value_computed(
worker,
am_tft_policy_row,
{"player_row": 1, "player_col": 0},
opp_action=1,
assert_debit=0,
)
- assert_(
+ assert_debit_value_computed(
worker,
am_tft_policy_col,
{"player_row": 1, "player_col": 1},
diff --git a/tests/marltoolbox/algos/amTFT/test_base_policy.py b/tests/marltoolbox/algos/amTFT/test_base_policy.py
index 8d7143b..898d76f 100644
--- a/tests/marltoolbox/algos/amTFT/test_base_policy.py
+++ b/tests/marltoolbox/algos/amTFT/test_base_policy.py
@@ -30,31 +30,31 @@ def init_amTFT(
def test__select_witch_algo_to_use():
am_tft_policy, env = init_amTFT()
- def assert_(working_state_idx, active_algo_idx):
+ def assert_active_algo_idx(working_state_idx, active_algo_idx):
am_tft_policy.working_state = base_policy.WORKING_STATES[
working_state_idx
]
- am_tft_policy._select_witch_algo_to_use()
+ am_tft_policy._select_witch_algo_to_use(None)
assert am_tft_policy.active_algo_idx == active_algo_idx
- assert_(
+ assert_active_algo_idx(
working_state_idx=0, active_algo_idx=base_policy.OWN_COOP_POLICY_IDX
)
- assert_(
+ assert_active_algo_idx(
working_state_idx=1, active_algo_idx=base_policy.OWN_SELFISH_POLICY_IDX
)
am_tft_policy.n_steps_to_punish = 0
- assert_(
+ assert_active_algo_idx(
working_state_idx=2, active_algo_idx=base_policy.OWN_COOP_POLICY_IDX
)
am_tft_policy.n_steps_to_punish = 1
- assert_(
+ assert_active_algo_idx(
working_state_idx=2, active_algo_idx=base_policy.OWN_SELFISH_POLICY_IDX
)
- assert_(
+ assert_active_algo_idx(
working_state_idx=3, active_algo_idx=base_policy.OWN_SELFISH_POLICY_IDX
)
- assert_(
+ assert_active_algo_idx(
working_state_idx=4, active_algo_idx=base_policy.OWN_COOP_POLICY_IDX
)
@@ -123,21 +123,20 @@ def test_lr_update():
)
one_policy_batch = multiagent_batch.policy_batches[env.players_ids[0]]
- am_tft_policy.on_global_var_update({"timestep": 0})
- am_tft_policy.learn_on_batch(one_policy_batch)
- for algo in am_tft_policy.algorithms:
- assert algo.cur_lr == base_lr
- for opt in algo._optimizers:
- for p in opt.param_groups:
- assert p["lr"] == algo.cur_lr
+ def _assert_lr_equals(policy, lr):
+ for algo in policy.algorithms:
+ assert algo.cur_lr == lr
+ for opt in algo._optimizers:
+ for p in opt.param_groups:
+ assert p["lr"] == lr
- am_tft_policy.on_global_var_update({"timestep": interm_global_timestep})
- am_tft_policy.learn_on_batch(one_policy_batch)
- for algo in am_tft_policy.algorithms:
- assert algo.cur_lr == final_lr
- for opt in algo._optimizers:
- for p in opt.param_groups:
- assert p["lr"] == algo.cur_lr
+ def _fake_n_step_assert_lr(policy, n_step, lr):
+ policy.on_global_var_update({"timestep": n_step})
+ policy.learn_on_batch(one_policy_batch)
+ _assert_lr_equals(policy, lr)
+
+ _fake_n_step_assert_lr(am_tft_policy, 0, base_lr)
+ _fake_n_step_assert_lr(am_tft_policy, interm_global_timestep, final_lr)
def test__is_punishment_planned():
@@ -148,12 +147,21 @@ def test__is_punishment_planned():
assert am_tft_policy._is_punishment_planned()
+from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
+
+
def test_on_episode_end():
am_tft_policy, env = init_amTFT(
{"working_state": base_policy.WORKING_STATES[2]}
)
+ base_env = _MultiAgentEnvToBaseEnv(
+ make_env=lambda _: env, existing_envs=[], num_envs=1
+ )
am_tft_policy.total_debit = 0
am_tft_policy.n_steps_to_punish = 0
- am_tft_policy.on_episode_end()
+ am_tft_policy.observed_n_step_in_current_epi = base_env.get_unwrapped()[
+ 0
+ ].max_steps
+ am_tft_policy.on_episode_end(base_env=base_env)
assert am_tft_policy.total_debit == 0
assert am_tft_policy.n_steps_to_punish == 0
diff --git a/tests/marltoolbox/algos/exploiters/evader_utils.py b/tests/marltoolbox/algos/exploiters/evader_utils.py
index 9d9be6d..bcd110a 100644
--- a/tests/marltoolbox/algos/exploiters/evader_utils.py
+++ b/tests/marltoolbox/algos/exploiters/evader_utils.py
@@ -9,7 +9,7 @@
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.dqn.dqn_torch_policy import postprocess_nstep_and_prio
from ray.rllib.agents.pg.pg_torch_policy import post_process_advantages
-from ray.rllib.agents.ppo.ppo_torch_policy import postprocess_ppo_gae
+from ray.rllib.agents.ppo.ppo_torch_policy import compute_gae_for_sample_batch
from ray.rllib.evaluation.sample_batch_builder import (
MultiAgentSampleBatchBuilder,
)
@@ -25,7 +25,7 @@
from marltoolbox.utils import postprocessing, miscellaneous
TEST_POLICIES = (
- (ppo.PPOTorchPolicy, postprocess_ppo_gae, ppo.DEFAULT_CONFIG),
+ (ppo.PPOTorchPolicy, compute_gae_for_sample_batch, ppo.DEFAULT_CONFIG),
(dqn.DQNTorchPolicy, postprocess_nstep_and_prio, dqn.DEFAULT_CONFIG),
(A3CTorchPolicy, add_advantages, a3c.DEFAULT_CONFIG),
(pg.PGTorchPolicy, post_process_advantages, pg.DEFAULT_CONFIG),
diff --git a/tests/marltoolbox/algos/test_welfare_coordination.py b/tests/marltoolbox/algos/test_welfare_coordination.py
new file mode 100644
index 0000000..046e1ca
--- /dev/null
+++ b/tests/marltoolbox/algos/test_welfare_coordination.py
@@ -0,0 +1,330 @@
+import random
+from marltoolbox.algos.welfare_coordination import MetaGameSolver
+
+BEST_PAYOFF = 1.5
+BEST_WELFARE = "best_welfare"
+WORST_PAYOFF = -0.5
+WORST_WELFARE = "worst_welfare"
+
+
+def test_end_to_end_wt_best_welfare_fn():
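+    # The outer loop randomizes the meta-game setup; the inner loop
+    # randomizes tau. The announced welfare set must always contain the
+    # clearly dominant welfare function.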
+ for _ in range(10):
+        meta_game_solver = _given_meta_game_with_a_clear_extreme_welfare_fn(
+ best=True
+ )
+ for _ in range(10):
+ _when_solving_meta_game(meta_game_solver)
+ _assert_best_welfare_in_announced_set(meta_game_solver)
+
+
+def test_end_to_end_wt_worst_welfare_fn():
+ for _ in range(10):
+        meta_game_solver = _given_meta_game_with_a_clear_extreme_welfare_fn(
+ worst=True
+ )
+ for _ in range(10):
+ _when_solving_meta_game(meta_game_solver)
+ _assert_best_welfare_not_in_announced_set(meta_game_solver)
+
+
+def _given_meta_game_with_a_clear_extreme_welfare_fn(best=False, worst=False):
+ assert best or worst
+ assert not (best and worst)
+ meta_game_solver = MetaGameSolver()
+ n_welfares = _get_random_number_of_welfare_fn()
+ own_player_idx = _get_random_position_of_players()
+ opp_player_idx = (own_player_idx + 1) % 2
+ welfares = ["welfare_" + str(el) for el in list(range(n_welfares - 1))]
+ if best:
+ welfares.append(BEST_WELFARE)
+ elif worst:
+ welfares.append(WORST_WELFARE)
+
+ all_welfare_pairs_wt_payoffs = (
+        _get_all_welfare_pairs_wt_extreme_payoffs_for_i(
+ welfares=welfares,
+ own_player_idx=own_player_idx,
+ best_welfare=BEST_WELFARE if best else None,
+ worst_welfare=WORST_WELFARE if worst else None,
+ )
+ )
+
+ meta_game_solver.setup_meta_game(
+ all_welfare_pairs_wt_payoffs,
+ own_player_idx=own_player_idx,
+ opp_player_idx=opp_player_idx,
+ own_default_welfare_fn=welfares[own_player_idx],
+ opp_default_welfare_fn=welfares[opp_player_idx],
+ )
+ return meta_game_solver
+
+
+def _when_solving_meta_game(meta_game_solver):
+ meta_game_solver.solve_meta_game(_get_random_tau())
+
+
+def _assert_best_welfare_in_announced_set(meta_game_solver):
+ assert BEST_WELFARE in meta_game_solver.welfare_set_to_annonce
+
+
+def _assert_best_welfare_not_in_announced_set(meta_game_solver):
+ assert BEST_WELFARE not in meta_game_solver.welfare_set_to_annonce
+
+
+def _get_random_tau():
+ return random.random()
+
+
+def _get_random_number_of_welfare_fn():
+ return random.randint(2, 4)
+
+
+def _get_random_position_of_players():
+ return random.randint(0, 1)
+
+
+def _get_all_welfare_pairs_wt_extreme_payoffs_for_i(
+ welfares,
+ own_player_idx,
+ best_welfare: str = None,
+ worst_welfare: str = None,
+):
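+    # Payoffs are drawn uniformly in [0, 1); pairs whose first welfare is
+    # the designated best/worst one get player i's payoff forced to
+    # BEST_PAYOFF / WORST_PAYOFF, making the expected announcement
+    # unambiguous.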
+ all_welfare_pairs_wt_payoffs = {}
+ for welfare_p1 in welfares:
+ for welfare_p2 in welfares:
+ welfare_pair_name = (
+ MetaGameSolver.from_pair_of_welfare_names_to_key(
+ welfare_p1, welfare_p2
+ )
+ )
+
+ all_welfare_pairs_wt_payoffs[welfare_pair_name] = [
+ random.random(),
+ random.random(),
+ ]
+ if best_welfare is not None and best_welfare == welfare_p1:
+ all_welfare_pairs_wt_payoffs[welfare_pair_name][
+ own_player_idx
+ ] = BEST_PAYOFF
+ elif worst_welfare is not None and worst_welfare == welfare_p1:
+ all_welfare_pairs_wt_payoffs[welfare_pair_name][
+ own_player_idx
+ ] = WORST_PAYOFF
+ return all_welfare_pairs_wt_payoffs
+
+
+def test__compute_meta_payoff():
+ for _ in range(100):
+ (
+ welfares,
+ all_welfare_pairs_wt_payoffs,
+ own_welfare_set,
+ opp_welfare_set,
+ payoff,
+ payoff_default,
+ own_default_welfare_fn,
+ opp_default_welfare_fn,
+ own_player_idx,
+ opp_player_idx,
+ ) = _given_this_all_welfare_pairs_wt_payoffs()
+
+ meta_payoff = _when_computing_meta_game_payoff(
+ all_welfare_pairs_wt_payoffs,
+ own_player_idx,
+ opp_player_idx,
+ own_default_welfare_fn,
+ opp_default_welfare_fn,
+ own_welfare_set,
+ opp_welfare_set,
+ )
+
+ _assert_get_the_right_payoffs_or_default_payoff(
+ own_welfare_set,
+ opp_welfare_set,
+ own_player_idx,
+ meta_payoff,
+ payoff,
+ payoff_default,
+ )
+
+
+def _when_computing_meta_game_payoff(
+ all_welfare_pairs_wt_payoffs,
+ own_player_idx,
+ opp_player_idx,
+ own_default_welfare_fn,
+ opp_default_welfare_fn,
+ own_welfare_set,
+ opp_welfare_set,
+):
+ meta_game_solver = MetaGameSolver()
+ meta_game_solver.setup_meta_game(
+ all_welfare_pairs_wt_payoffs,
+ own_player_idx=own_player_idx,
+ opp_player_idx=opp_player_idx,
+ own_default_welfare_fn=own_default_welfare_fn,
+ opp_default_welfare_fn=opp_default_welfare_fn,
+ )
+ meta_payoff = meta_game_solver._compute_meta_payoff(
+ own_welfare_set, opp_welfare_set
+ )
+ return meta_payoff
+
+
+def _given_this_all_welfare_pairs_wt_payoffs():
+ n_welfares = _get_random_number_of_welfare_fn()
+ welfares = ["welfare_" + str(el) for el in list(range(n_welfares))]
+ own_player_idx = _get_random_position_of_players()
+ opp_player_idx = (own_player_idx + 1) % 2
+ all_welfare_pairs_wt_payoffs = {}
+
+ own_welfare_set, opp_welfare_set, payoff = _add_nominal_case(
+ welfares, all_welfare_pairs_wt_payoffs, own_player_idx
+ )
+
+ (
+ own_default_welfare_fn,
+ opp_default_welfare_fn,
+ payoff_default,
+ ) = _add_default_case(
+ welfares, all_welfare_pairs_wt_payoffs, own_player_idx
+ )
+
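+    # If the randomly drawn default pair coincides with the nominal pair,
+    # _add_default_case overwrote the nominal payoff, so the expected
+    # payoff becomes the default one.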
+ if (
+ own_default_welfare_fn
+ == list(own_welfare_set)[0]
+ == list(opp_welfare_set)[0]
+ == opp_default_welfare_fn
+ ):
+ payoff = payoff_default
+ return (
+ welfares,
+ all_welfare_pairs_wt_payoffs,
+ own_welfare_set,
+ opp_welfare_set,
+ payoff,
+ payoff_default,
+ own_default_welfare_fn,
+ opp_default_welfare_fn,
+ own_player_idx,
+ opp_player_idx,
+ )
+
+
+def _assert_get_the_right_payoffs_or_default_payoff(
+ own_welfare_set,
+ opp_welfare_set,
+ own_player_idx,
+ meta_payoff,
+ payoff,
+ payoff_default,
+):
+ if len(own_welfare_set & opp_welfare_set) > 0:
+ assert meta_payoff[own_player_idx] == payoff
+ else:
+ assert meta_payoff[own_player_idx] == payoff_default
+
+
+def _add_nominal_case(welfares, all_welfare_pairs_wt_payoffs, own_player_idx):
+ own_welfare_set = set(random.sample(welfares, 1))
+ opp_welfare_set = set(random.sample(welfares, 1))
+ welfare_pair_name = MetaGameSolver.from_pair_of_welfare_names_to_key(
+ list(own_welfare_set)[0], list(opp_welfare_set)[0]
+ )
+ payoff = random.random()
+ all_welfare_pairs_wt_payoffs[welfare_pair_name] = [-1, -1]
+ all_welfare_pairs_wt_payoffs[welfare_pair_name][own_player_idx] = payoff
+
+ return own_welfare_set, opp_welfare_set, payoff
+
+
+def _add_default_case(welfares, all_welfare_pairs_wt_payoffs, own_player_idx):
+ own_default_welfare_fn = random.sample(welfares, 1)[0]
+ opp_default_welfare_fn = random.sample(welfares, 1)[0]
+ welfare_default_pair_name = (
+ MetaGameSolver.from_pair_of_welfare_names_to_key(
+ own_default_welfare_fn, opp_default_welfare_fn
+ )
+ )
+ payoff_default = random.random()
+
+ all_welfare_pairs_wt_payoffs[welfare_default_pair_name] = [-1, -1]
+ all_welfare_pairs_wt_payoffs[welfare_default_pair_name][
+ own_player_idx
+ ] = payoff_default
+
+ return own_default_welfare_fn, opp_default_welfare_fn, payoff_default
+
+
+def test__list_all_set_of_welfare_fn():
+ for _ in range(100):
+ (
+ own_player_idx,
+ opp_player_idx,
+ welfares,
+ all_welfare_pairs_wt_payoffs,
+ ) = _given_n_welfare_fn()
+ meta_game_solver = _when_setting_the_game(
+ all_welfare_pairs_wt_payoffs, own_player_idx, opp_player_idx
+ )
+ _assert_right_number_of_sets_and_presence_of_single_and_pairs(
+ meta_game_solver, welfares
+ )
+
+
+def _given_n_welfare_fn():
+ n_welfares = _get_random_number_of_welfare_fn()
+ welfares = ["welfare_" + str(el) for el in list(range(n_welfares))]
+ own_player_idx = _get_random_position_of_players()
+ opp_player_idx = (own_player_idx + 1) % 2
+
+ all_welfare_pairs_wt_payoffs = (
+        _get_all_welfare_pairs_wt_extreme_payoffs_for_i(
+ welfares=welfares,
+ own_player_idx=own_player_idx,
+ best_welfare=welfares[0],
+ )
+ )
+ return (
+ own_player_idx,
+ opp_player_idx,
+ welfares,
+ all_welfare_pairs_wt_payoffs,
+ )
+
+
+def _when_setting_the_game(
+ all_welfare_pairs_wt_payoffs, own_player_idx, opp_player_idx
+):
+ meta_game_solver = MetaGameSolver()
+ meta_game_solver.setup_meta_game(
+ all_welfare_pairs_wt_payoffs,
+ own_player_idx=own_player_idx,
+ opp_player_idx=opp_player_idx,
+ own_default_welfare_fn="welfare_0",
+ opp_default_welfare_fn="welfare_1",
+ )
+ return meta_game_solver
+
+
+def _assert_right_number_of_sets_and_presence_of_single_and_pairs(
+ meta_game_solver, welfares
+):
+ meta_game_solver._list_all_set_of_welfare_fn()
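+    # The solver enumerates every non-empty subset of the welfare
+    # functions, i.e. 2 ** n - 1 sets for n welfare functions.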
+ if len(welfares) == 2:
+ assert len(meta_game_solver.welfare_fn_sets) == 3
+ elif len(welfares) == 3:
+ assert len(meta_game_solver.welfare_fn_sets) == 3 + 3 + 1
+ elif len(welfares) == 4:
+ print(
+ "meta_game_solver.welfare_fn_sets",
+ meta_game_solver.welfare_fn_sets,
+ )
+ assert len(meta_game_solver.welfare_fn_sets) == 4 + 6 + 4 + 1
+ for welfare in welfares:
+ assert frozenset([welfare]) in meta_game_solver.welfare_fn_sets
+ for welfare_2 in welfares:
+ assert (
+ frozenset([welfare, welfare_2])
+ in meta_game_solver.welfare_fn_sets
+ )
diff --git a/tests/marltoolbox/envs/coin_game_tests_utils.py b/tests/marltoolbox/envs/coin_game_tests_utils.py
new file mode 100644
index 0000000..6fe771f
--- /dev/null
+++ b/tests/marltoolbox/envs/coin_game_tests_utils.py
@@ -0,0 +1,440 @@
+import random
+import numpy as np
+
+
+def init_several_envs(classes, **kwargs):
+ return [init_env(env_class=class_, **kwargs) for class_ in classes]
+
+
+def init_env(
+ env_class,
+ max_steps,
+ seed=None,
+ grid_size=3,
+ players_can_pick_same_coin=True,
+ same_obs_for_each_player=False,
+ batch_size=None,
+):
+ config = {
+ "max_steps": max_steps,
+ "grid_size": grid_size,
+ "both_players_can_pick_the_same_coin": players_can_pick_same_coin,
+ "same_obs_for_each_player": same_obs_for_each_player,
+ }
+ if batch_size is not None:
+ config["batch_size"] = batch_size
+ env = env_class(config)
+ env.seed(seed)
+ return env
+
+
+def check_custom_obs(
+ obs,
+ grid_size,
+ batch_size=None,
+ n_in_0=1.0,
+ n_in_1=1.0,
+ n_in_2_and_above=1.0,
+ n_layers=4,
+):
+ assert len(obs) == 2, "two players"
+ for player_obs in obs.values():
+ if batch_size is None:
+ check_single_obs(
+ player_obs,
+ grid_size,
+ n_layers,
+ n_in_0,
+ n_in_1,
+ n_in_2_and_above,
+ )
+ else:
+ for i in range(batch_size):
+ check_single_obs(
+ player_obs[i, ...],
+ grid_size,
+ n_layers,
+ n_in_0,
+ n_in_1,
+ n_in_2_and_above,
+ )
+
+
+def check_single_obs(
+ player_obs, grid_size, n_layers, n_in_0, n_in_1, n_in_2_and_above
+):
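+    # Layer 0 holds the red player, layer 1 the blue player, and the
+    # remaining layers hold the coin(s).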
+ assert player_obs.shape == (grid_size, grid_size, n_layers)
+    assert (
+        player_obs[..., 0].sum() == n_in_0
+    ), f"expected {n_in_0} red player(s) in grid: {player_obs[..., 0]}"
+    assert (
+        player_obs[..., 1].sum() == n_in_1
+    ), f"expected {n_in_1} blue player(s) in grid: {player_obs[..., 1]}"
+    assert (
+        player_obs[..., 2:].sum() == n_in_2_and_above
+    ), f"expected {n_in_2_and_above} coin(s) in grid: {player_obs[..., 2:]}"
+
+
+def assert_logger_buffer_size(env, n_steps):
+ assert_attributes_len_equals_value(
+ env,
+ n_steps,
+ )
+
+
+def assert_attributes_len_equals_value(
+ object_,
+ value,
+ attributes=("red_pick", "red_pick_own", "blue_pick", "blue_pick_own"),
+):
+ for attribute in attributes:
+ assert len(getattr(object_, attribute)) == value
+
+
+def helper_test_reset(envs, check_obs_fn, **kwargs):
+ for env in envs:
+ obs = env.reset()
+ check_obs_fn(obs, **kwargs)
+ assert_logger_buffer_size(env, n_steps=0)
+
+
+def helper_test_step(envs, check_obs_fn, **kwargs):
+ for env in envs:
+ obs = env.reset()
+ check_obs_fn(obs, **kwargs)
+ assert_logger_buffer_size(env, n_steps=0)
+
+ actions = _get_random_action(env, **kwargs)
+ obs, reward, done, info = env.step(actions)
+ check_obs_fn(obs, **kwargs)
+ assert_logger_buffer_size(env, n_steps=1)
+ assert not done["__all__"]
+
+
+def _get_random_action(env, **kwargs):
+ if "batch_size" in kwargs.keys():
+ actions = _get_random_action_batch(env, kwargs["batch_size"])
+ else:
+ actions = _get_random_single_action(env)
+ return actions
+
+
+def _get_random_single_action(env):
+ actions = {
+ policy_id: random.randint(0, env.NUM_ACTIONS - 1)
+ for policy_id in env.players_ids
+ }
+ return actions
+
+
+def _get_random_action_batch(env, batch_size):
+ actions = {
+ policy_id: [
+ random.randint(0, env.NUM_ACTIONS - 1) for _ in range(batch_size)
+ ]
+ for policy_id in env.players_ids
+ }
+ return actions
+
+
+def helper_test_multiple_steps(envs, n_steps, check_obs_fn, **kwargs):
+ for env in envs:
+ obs = env.reset()
+ check_obs_fn(obs, **kwargs)
+ assert_logger_buffer_size(env, n_steps=0)
+
+ for step_i in range(1, n_steps, 1):
+ actions = _get_random_action(env, **kwargs)
+ obs, reward, done, info = env.step(actions)
+ check_obs_fn(obs, **kwargs)
+ assert_logger_buffer_size(env, n_steps=step_i)
+ assert not done["__all__"]
+
+
+def helper_test_multi_ple_episodes(
+ envs,
+ n_steps,
+ max_steps,
+ check_obs_fn,
+ **kwargs,
+):
+ for env in envs:
+ obs = env.reset()
+ check_obs_fn(obs, **kwargs)
+ assert_logger_buffer_size(env, n_steps=0)
+
+ step_i = 0
+ for _ in range(n_steps):
+ step_i += 1
+ actions = _get_random_action(env, **kwargs)
+ obs, reward, done, info = env.step(actions)
+ check_obs_fn(obs, **kwargs)
+ assert_logger_buffer_size(env, n_steps=step_i)
+ assert not done["__all__"] or (
+ step_i == max_steps and done["__all__"]
+ )
+ if done["__all__"]:
+ obs = env.reset()
+ check_obs_fn(obs, **kwargs)
+ assert_logger_buffer_size(env, n_steps=0)
+ step_i = 0
+
+
+def helper_assert_info(repetitions=10, **kwargs):
+ if "batch_size" in kwargs.keys():
+ for _ in range(repetitions):
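+            # Shift each batch element by a random offset within the
+            # episode so the batched env is exercised at staggered steps.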
+ batch_deltas = np.random.randint(
+ 0, kwargs["max_steps"] - 1, size=kwargs["batch_size"]
+ )
+ helper_assert_info_one_time(batch_deltas=batch_deltas, **kwargs)
+ else:
+ helper_assert_info_one_time(batch_deltas=None, **kwargs)
+
+
+def helper_assert_info_one_time(
+ n_steps,
+ p_red_act,
+ p_blue_act,
+ envs,
+ max_steps,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ red_speed,
+ blue_speed,
+ red_own,
+ blue_own,
+ check_obs_fn,
+ overwrite_pos_fn,
+ c_red_coin=None,
+ batch_deltas=None,
+ blue_coop_fraction=None,
+ red_coop_fraction=None,
+ red_coop_speed=None,
+ blue_coop_speed=None,
+ delta_err=0.01,
+ **check_obs_kwargs,
+):
+ for env_i, env in enumerate(envs):
+ step_i = 0
+ obs = env.reset()
+ check_obs_fn(obs, **check_obs_kwargs)
+ assert_logger_buffer_size(env, n_steps=0)
+ _overwrite_pos_helper(
+ batch_deltas,
+ overwrite_pos_fn,
+ step_i,
+ max_steps,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ c_red_coin,
+ )
+
+ for _ in range(n_steps):
+ actions = _read_actions(
+ p_red_act,
+ p_blue_act,
+ step_i,
+ batch_deltas,
+ n_steps_in_epi=max_steps,
+ )
+ step_i += 1
+ obs, reward, done, info = env.step(actions)
+ check_obs_fn(obs, **check_obs_kwargs)
+ assert_logger_buffer_size(env, n_steps=step_i)
+ assert not done["__all__"] or (
+ step_i == max_steps and done["__all__"]
+ )
+
+ if done["__all__"]:
+ print("info", info)
+ print("step_i", step_i)
+ print("env", env)
+ print("env_i", env_i)
+ _assert_close_enough(
+ info["player_red"]["pick_speed"], red_speed, delta_err
+ )
+ _assert_close_enough(
+ info["player_blue"]["pick_speed"], blue_speed, delta_err
+ )
+ assert_not_present_in_dict_or_close_to(
+ "pick_own_color", red_own, info, "player_red", delta_err
+ )
+ assert_not_present_in_dict_or_close_to(
+ "pick_own_color", blue_own, info, "player_blue", delta_err
+ )
+ _assert_ssdmmcg_cooperation_items(
+ red_coop_fraction,
+ blue_coop_fraction,
+ red_coop_speed,
+ blue_coop_speed,
+ info,
+ delta_err,
+ )
+
+ obs = env.reset()
+ check_obs_fn(obs, **check_obs_kwargs)
+ assert_logger_buffer_size(env, n_steps=0)
+ step_i = 0
+
+ _overwrite_pos_helper(
+ batch_deltas,
+ overwrite_pos_fn,
+ step_i,
+ max_steps,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ c_red_coin,
+ )
+
+
+def assert_not_present_in_dict_or_close_to(
+ key, value, info, player, delta_err
+):
+ if value is None:
+ assert key not in info[player]
+ else:
+ _assert_close_enough(info[player][key], value, delta_err)
+
+
+def _assert_close_enough(value, target, delta_err):
+    assert abs(value - target) < delta_err, (
+        f"{abs(value - target)} should be < {delta_err}"
+    )
+
+
+def _read_actions(
+ p_red_act, p_blue_act, step_i, batch_deltas=None, n_steps_in_epi=None
+):
+ if batch_deltas is not None:
+ return _read_actions_batch(
+ p_red_act, p_blue_act, step_i, batch_deltas, n_steps_in_epi
+ )
+ else:
+ return _read_single_action(p_red_act, p_blue_act, step_i)
+
+
+def _read_actions_batch(
+ p_red_act, p_blue_act, step_i, batch_deltas, n_steps_in_epi
+):
+ actions = {
+ "player_red": [
+ p_red_act[(step_i + delta) % n_steps_in_epi]
+ for delta in batch_deltas
+ ],
+ "player_blue": [
+ p_blue_act[(step_i + delta) % n_steps_in_epi]
+ for delta in batch_deltas
+ ],
+ }
+ return actions
+
+
+def _read_single_action(p_red_act, p_blue_act, step_i):
+ actions = {
+ "player_red": p_red_act[step_i - 1],
+ "player_blue": p_blue_act[step_i - 1],
+ }
+ return actions
+
+
+def _assert_ssdmmcg_cooperation_items(
+ red_coop_fraction,
+ blue_coop_fraction,
+ red_coop_speed,
+ blue_coop_speed,
+ info,
+ delta_err,
+):
+ if _is_using_ssdmmcg(
+ blue_coop_fraction,
+ red_coop_fraction,
+ red_coop_speed,
+ blue_coop_speed,
+ ):
+
+ assert_not_present_in_dict_or_close_to(
+ "blue_coop_fraction",
+ blue_coop_fraction,
+ info,
+ "player_blue",
+ delta_err,
+ )
+ assert_not_present_in_dict_or_close_to(
+ "red_coop_fraction",
+ red_coop_fraction,
+ info,
+ "player_red",
+ delta_err,
+ )
+ assert_not_present_in_dict_or_close_to(
+ "red_coop_speed",
+ red_coop_speed,
+ info,
+ "player_red",
+ delta_err,
+ )
+ assert_not_present_in_dict_or_close_to(
+ "blue_coop_speed",
+ blue_coop_speed,
+ info,
+ "player_blue",
+ delta_err,
+ )
+
+
+def _is_using_ssdmmcg(
+ blue_coop_fraction, red_coop_fraction, red_coop_speed, blue_coop_speed
+):
+ return (
+ blue_coop_fraction is not None
+ or red_coop_fraction is not None
+ or red_coop_speed is not None
+ or blue_coop_speed is not None
+ )
+
+
+def _overwrite_pos_helper(
+ batch_deltas,
+ overwrite_pos_fn,
+ step_i,
+ max_steps,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ c_red_coin,
+):
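+    # Batched envs get the step index plus per-element deltas so each
+    # batch element can be positioned independently; single envs only
+    # receive the positions for the current step.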
+ if batch_deltas is not None:
+ overwrite_pos_fn(
+ step_i,
+ batch_deltas,
+ max_steps,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ c_red_coin=c_red_coin,
+ )
+ else:
+ overwrite_pos_fn(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ c_red_coin=c_red_coin,
+ )
+
+
+def shift_consistently(list_, step_i, n_steps_in_epi, batch_deltas):
+ return [list_[(step_i + delta) % n_steps_in_epi] for delta in batch_deltas]
diff --git a/tests/marltoolbox/envs/test_coin_game.py b/tests/marltoolbox/envs/test_coin_game.py
index c62f338..4ab5264 100644
--- a/tests/marltoolbox/envs/test_coin_game.py
+++ b/tests/marltoolbox/envs/test_coin_game.py
@@ -5,137 +5,76 @@
from marltoolbox.envs.coin_game import CoinGame, AsymCoinGame
-
# TODO add tests for grid_size != 3
+from coin_game_tests_utils import (
+ check_custom_obs,
+ assert_logger_buffer_size,
+ helper_test_reset,
+ helper_test_step,
+ init_several_envs,
+ helper_test_multiple_steps,
+ helper_test_multi_ple_episodes,
+ helper_assert_info,
+)
+
+
+def init_my_envs(
+ max_steps,
+ grid_size,
+ players_can_pick_same_coin=True,
+ same_obs_for_each_player=True,
+):
+ return init_several_envs(
+ (CoinGame, AsymCoinGame),
+ max_steps=max_steps,
+ grid_size=grid_size,
+ players_can_pick_same_coin=players_can_pick_same_coin,
+ same_obs_for_each_player=same_obs_for_each_player,
+ )
+
def test_reset():
max_steps, grid_size = 20, 3
- envs = init_several_env(max_steps, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
-
-def init_several_env(max_steps, grid_size,
- players_can_pick_same_coin=True,
- same_obs_for_each_player=True):
- coin_game = init_env(
- max_steps,
- CoinGame,
- grid_size,
- players_can_pick_same_coin=players_can_pick_same_coin,
- same_obs_for_each_player=same_obs_for_each_player)
- asymm_coin_game = init_env(
- max_steps,
- AsymCoinGame,
- grid_size,
- players_can_pick_same_coin=players_can_pick_same_coin,
- same_obs_for_each_player=same_obs_for_each_player)
- return [coin_game, asymm_coin_game]
-
-
-def init_env(max_steps,
- env_class,
- seed=None,
- grid_size=3,
- players_can_pick_same_coin=True,
- same_obs_for_each_player=False):
- config = {
- "max_steps": max_steps,
- "grid_size": grid_size,
- "both_players_can_pick_the_same_coin": players_can_pick_same_coin,
- "same_obs_for_each_player": same_obs_for_each_player,
- }
- env = env_class(config)
- env.seed(seed)
- return env
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_reset(envs, check_obs, grid_size=grid_size)
def check_obs(obs, grid_size):
- assert len(obs) == 2, "two players"
- for key, player_obs in obs.items():
- assert player_obs.shape == (grid_size, grid_size, 4)
- assert player_obs[..., 0].sum() == 1.0, \
- f"observe 1 player red in grid: {player_obs[..., 0]}"
- assert player_obs[..., 1].sum() == 1.0, \
- f"observe 1 player blue in grid: {player_obs[..., 1]}"
- assert player_obs[..., 2:].sum() == 1.0, \
- f"observe 1 coin in grid: {player_obs[..., 0]}"
-
-
-def assert_logger_buffer_size(env, n_steps):
- assert len(env.red_pick) == n_steps
- assert len(env.red_pick_own) == n_steps
- assert len(env.blue_pick) == n_steps
- assert len(env.blue_pick_own) == n_steps
+ check_custom_obs(obs, grid_size)
def test_step():
max_steps, grid_size = 20, 3
- envs = init_several_env(max_steps, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1)
- for policy_id in env.players_ids}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=1)
- assert not done["__all__"]
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_step(envs, check_obs, grid_size=grid_size)
def test_multiple_steps():
max_steps, grid_size = 20, 3
n_steps = int(max_steps * 0.75)
- envs = init_several_env(max_steps, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- for step_i in range(1, n_steps, 1):
- actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1)
- for policy_id in env.players_ids}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"]
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_multiple_steps(
+ envs,
+ n_steps,
+ check_obs,
+ grid_size=grid_size,
+ )
def test_multiple_episodes():
max_steps, grid_size = 20, 3
n_steps = int(max_steps * 8.25)
- envs = init_several_env(max_steps, grid_size)
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_multi_ple_episodes(
+ envs,
+ n_steps,
+ max_steps,
+ check_obs,
+ grid_size=grid_size,
+ )
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
- for _ in range(n_steps):
- step_i += 1
- actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1)
- for policy_id in env.players_ids}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"] or \
- (step_i == max_steps and done["__all__"])
- if done["__all__"]:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
-
-
-def overwrite_pos(env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos):
+def overwrite_pos(env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, **kwargs):
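+    # **kwargs absorbs extra arguments (e.g. c_red_coin) passed by the
+    # shared helpers; this env infers the active coin color from which
+    # coin position is None.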
assert c_red_pos is None or c_blue_pos is None
if c_red_pos is None:
env.red_coin = 0
@@ -154,46 +93,6 @@ def overwrite_pos(env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos):
env.red_coin = np.array(env.red_coin)
-def assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed, blue_speed, red_own, blue_own):
- step_i = 0
- delta_err = 0.01
- for _ in range(n_steps):
- step_i += 1
- actions = {"player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1]}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"] or (step_i == max_steps and done["__all__"])
-
- if done["__all__"]:
- assert abs(info["player_red"]["pick_speed"] - red_speed) \
- < delta_err
- assert abs(info["player_blue"]["pick_speed"] - blue_speed) \
- < delta_err
-
- if red_own is None:
- assert "pick_own_color" not in info["player_red"]
- else:
- assert abs(info["player_red"]["pick_own_color"] - red_own) \
- < delta_err
- if blue_own is None:
- assert "pick_own_color" not in info["player_blue"]
- else:
- assert abs(info["player_blue"]["pick_own_color"] - blue_own) \
- < delta_err
-
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
-
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
-
-
def test_logged_info_no_picking():
p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
@@ -203,32 +102,51 @@ def test_logged_info_no_picking():
c_blue_pos = [None, None, None, None]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
+ envs = init_my_envs(max_steps, grid_size)
for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
-
- envs = init_several_env(max_steps, grid_size,
- players_can_pick_same_coin=False)
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
+
+ envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False)
for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__red_pick_red_all_the_time():
@@ -240,32 +158,47 @@ def test_logged_info__red_pick_red_all_the_time():
c_blue_pos = [None, None, None, None]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=0.0, red_own=1.0, blue_own=None)
-
- envs = init_several_env(max_steps, grid_size,
- players_can_pick_same_coin=False)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=0.0, red_own=1.0, blue_own=None)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=1.0,
+ blue_own=None,
+ )
+
+ envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=1.0,
+ blue_own=None,
+ )
def test_logged_info__blue_pick_red_all_the_time():
@@ -277,32 +210,47 @@ def test_logged_info__blue_pick_red_all_the_time():
c_blue_pos = [None, None, None, None]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=0.0)
-
- envs = init_several_env(max_steps, grid_size,
- players_can_pick_same_coin=False)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=0.0)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=0.0,
+ )
+
+ envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=0.0,
+ )
def test_logged_info__blue_pick_blue_all_the_time():
@@ -314,32 +262,47 @@ def test_logged_info__blue_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=1.0)
-
- envs = init_several_env(max_steps, grid_size,
- players_can_pick_same_coin=False)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=1.0)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=1.0,
+ )
+
+ envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=1.0,
+ )
def test_logged_info__red_pick_blue_all_the_time():
@@ -351,32 +314,47 @@ def test_logged_info__red_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None)
-
- envs = init_several_env(max_steps, grid_size,
- players_can_pick_same_coin=False)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=0.0,
+ blue_own=None,
+ )
+
+ envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=0.0,
+ blue_own=None,
+ )
def test_logged_info__both_pick_blue_all_the_time():
@@ -388,18 +366,26 @@ def test_logged_info__both_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=1.0, red_own=0.0, blue_own=1.0)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=0.0,
+ blue_own=1.0,
+ )
def test_logged_info__both_pick_red_all_the_time():
@@ -411,18 +397,26 @@ def test_logged_info__both_pick_red_all_the_time():
c_blue_pos = [None, None, None, None]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=1.0, red_own=1.0, blue_own=0.0)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=1.0,
+ blue_own=0.0,
+ )
def test_logged_info__both_pick_red_half_the_time():
@@ -434,18 +428,26 @@ def test_logged_info__both_pick_red_half_the_time():
c_blue_pos = [None, None, None, None]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.5, blue_speed=0.5, red_own=1.0, blue_own=0.0)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.5,
+ blue_speed=0.5,
+ red_own=1.0,
+ blue_own=0.0,
+ )
def test_logged_info__both_pick_blue_half_the_time():
@@ -457,18 +459,26 @@ def test_logged_info__both_pick_blue_half_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.5, blue_speed=0.5, red_own=0.0, blue_own=1.0)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.5,
+ blue_speed=0.5,
+ red_own=0.0,
+ blue_own=1.0,
+ )
def test_logged_info__both_pick_blue():
@@ -480,18 +490,26 @@ def test_logged_info__both_pick_blue():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.25, blue_speed=0.5, red_own=0.0, blue_own=1.0)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.25,
+ blue_speed=0.5,
+ red_own=0.0,
+ blue_own=1.0,
+ )
def test_logged_info__pick_half_the_time_half_blue_half_red():
@@ -503,201 +521,420 @@ def test_logged_info__pick_half_the_time_half_blue_half_red():
c_blue_pos = [None, [1, 1], None, [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.5, blue_speed=0.5, red_own=0.5, blue_own=0.5)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.5,
+ blue_speed=0.5,
+ red_own=0.5,
+ blue_own=0.5,
+ )
def test_observations_are_invariant_to_the_player_trained_in_reset():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], None, [0, 1], None, None,
- [2, 2], [0, 0], None, None, [2, 1]]
- c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2],
- None, None, [0, 0], [2, 1], None]
+ c_red_pos = [
+ [1, 1],
+ None,
+ [0, 1],
+ None,
+ None,
+ [2, 2],
+ [0, 0],
+ None,
+ None,
+ [2, 1],
+ ]
+ c_blue_pos = [
+ None,
+ [1, 1],
+ None,
+ [0, 1],
+ [2, 2],
+ None,
+ None,
+ [0, 0],
+ [2, 1],
+ None,
+ ]
max_steps, grid_size = 10, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size,
- same_obs_for_each_player=False)
+ envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=False)
for env_i, env in enumerate(envs):
obs = env.reset()
assert_obs_is_symmetrical(obs, env)
step_i = 0
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
for _ in range(n_steps):
step_i += 1
- actions = {"player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1]}
+ actions = {
+ "player_red": p_red_act[step_i - 1],
+ "player_blue": p_blue_act[step_i - 1],
+ }
_, _, _, _ = env.step(actions)
if step_i == max_steps:
break
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
def assert_obs_is_symmetrical(obs, env):
- assert np.all(obs[env.players_ids[0]][..., 0] ==
- obs[env.players_ids[1]][..., 1])
- assert np.all(obs[env.players_ids[1]][..., 0] ==
- obs[env.players_ids[0]][..., 1])
- assert np.all(obs[env.players_ids[0]][..., 2] ==
- obs[env.players_ids[1]][..., 3])
- assert np.all(obs[env.players_ids[1]][..., 2] ==
- obs[env.players_ids[0]][..., 3])
+ assert np.all(
+ obs[env.players_ids[0]][..., 0] == obs[env.players_ids[1]][..., 1]
+ )
+ assert np.all(
+ obs[env.players_ids[1]][..., 0] == obs[env.players_ids[0]][..., 1]
+ )
+ assert np.all(
+ obs[env.players_ids[0]][..., 2] == obs[env.players_ids[1]][..., 3]
+ )
+ assert np.all(
+ obs[env.players_ids[1]][..., 2] == obs[env.players_ids[0]][..., 3]
+ )
def test_observations_are_invariant_to_the_player_trained_in_step():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], None, [0, 1], None, None,
- [2, 2], [0, 0], None, None, [2, 1]]
- c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2],
- None, None, [0, 0], [2, 1], None]
+ c_red_pos = [
+ [1, 1],
+ None,
+ [0, 1],
+ None,
+ None,
+ [2, 2],
+ [0, 0],
+ None,
+ None,
+ [2, 1],
+ ]
+ c_blue_pos = [
+ None,
+ [1, 1],
+ None,
+ [0, 1],
+ [2, 2],
+ None,
+ None,
+ [0, 0],
+ [2, 1],
+ None,
+ ]
max_steps, grid_size = 10, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size,
- same_obs_for_each_player=False)
+ envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=False)
for env_i, env in enumerate(envs):
_ = env.reset()
step_i = 0
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
for _ in range(n_steps):
step_i += 1
- actions = {"player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1]}
+ actions = {
+ "player_red": p_red_act[step_i - 1],
+ "player_blue": p_blue_act[step_i - 1],
+ }
obs, reward, done, info = env.step(actions)
# assert observations are symmetrical respective to the actions
if step_i % 2 == 1:
obs_step_odd = obs
elif step_i % 2 == 0:
- assert np.all(obs[env.players_ids[0]] ==
- obs_step_odd[env.players_ids[1]])
- assert np.all(obs[env.players_ids[1]] ==
- obs_step_odd[env.players_ids[0]])
+ assert np.all(
+ obs[env.players_ids[0]] == obs_step_odd[env.players_ids[1]]
+ )
+ assert np.all(
+ obs[env.players_ids[1]] == obs_step_odd[env.players_ids[0]]
+ )
assert_obs_is_symmetrical(obs, env)
if step_i == max_steps:
break
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
def test_observations_are_not_invariant_to_the_player_trained_in_reset():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], None, [0, 1], None, None,
- [2, 2], [0, 0], None, None, [2, 1]]
- c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2],
- None, None, [0, 0], [2, 1], None]
+ c_red_pos = [
+ [1, 1],
+ None,
+ [0, 1],
+ None,
+ None,
+ [2, 2],
+ [0, 0],
+ None,
+ None,
+ [2, 1],
+ ]
+ c_blue_pos = [
+ None,
+ [1, 1],
+ None,
+ [0, 1],
+ [2, 2],
+ None,
+ None,
+ [0, 0],
+ [2, 1],
+ None,
+ ]
max_steps, grid_size = 10, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size,
- same_obs_for_each_player=True)
+ envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=True)
for env_i, env in enumerate(envs):
obs = env.reset()
assert_obs_is_not_symmetrical(obs, env)
step_i = 0
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
for _ in range(n_steps):
step_i += 1
- actions = {"player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1]}
+ actions = {
+ "player_red": p_red_act[step_i - 1],
+ "player_blue": p_blue_act[step_i - 1],
+ }
_, _, _, _ = env.step(actions)
if step_i == max_steps:
break
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
def assert_obs_is_not_symmetrical(obs, env):
- assert np.all(obs[env.players_ids[0]] ==
- obs[env.players_ids[1]])
+ assert np.all(obs[env.players_ids[0]] == obs[env.players_ids[1]])
def test_observations_are_not_invariant_to_the_player_trained_in_step():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], None, [0, 1], None, None,
- [2, 2], [0, 0], None, None, [2, 1]]
- c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2],
- None, None, [0, 0], [2, 1], None]
+ c_red_pos = [
+ [1, 1],
+ None,
+ [0, 1],
+ None,
+ None,
+ [2, 2],
+ [0, 0],
+ None,
+ None,
+ [2, 1],
+ ]
+ c_blue_pos = [
+ None,
+ [1, 1],
+ None,
+ [0, 1],
+ [2, 2],
+ None,
+ None,
+ [0, 0],
+ [2, 1],
+ None,
+ ]
max_steps, grid_size = 10, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size,
- same_obs_for_each_player=True)
+ envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=True)
for env_i, env in enumerate(envs):
_ = env.reset()
step_i = 0
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
for _ in range(n_steps):
step_i += 1
- actions = {"player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1]}
+ actions = {
+ "player_red": p_red_act[step_i - 1],
+ "player_blue": p_blue_act[step_i - 1],
+ }
obs, reward, done, info = env.step(actions)
# assert observations are symmetrical respective to the actions
if step_i % 2 == 1:
obs_step_odd = obs
elif step_i % 2 == 0:
- assert np.any(obs[env.players_ids[0]] !=
- obs_step_odd[env.players_ids[1]])
- assert np.any(obs[env.players_ids[1]] !=
- obs_step_odd[env.players_ids[0]])
+ assert np.any(
+ obs[env.players_ids[0]] != obs_step_odd[env.players_ids[1]]
+ )
+ assert np.any(
+ obs[env.players_ids[1]] != obs_step_odd[env.players_ids[0]]
+ )
assert_obs_is_not_symmetrical(obs, env)
if step_i == max_steps:
break
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
@flaky(max_runs=4, min_passes=1)
def test_who_pick_is_random():
- size = 100
+ size = 1000
p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] * size
p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] * size
p_red_act = [0, 0, 0, 0] * size
@@ -706,29 +943,45 @@ def test_who_pick_is_random():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] * size
max_steps, grid_size = int(4 * size), 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=1.0, red_own=0.0, blue_own=1.0)
-
- envs = init_several_env(max_steps, grid_size,
- players_can_pick_same_coin=False)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.5, blue_speed=0.5, red_own=0.0, blue_own=1.0)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=0.0,
+ blue_own=1.0,
+ )
+
+ envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ grid_size=grid_size,
+ red_speed=0.5,
+ blue_speed=0.5,
+ red_own=0.0,
+ blue_own=1.0,
+ delta_err=0.05,
+ )
diff --git a/tests/marltoolbox/envs/test_mixed_motive_coin_game.py b/tests/marltoolbox/envs/test_mixed_motive_coin_game.py
index 6cd0ecf..fb6e26e 100644
--- a/tests/marltoolbox/envs/test_mixed_motive_coin_game.py
+++ b/tests/marltoolbox/envs/test_mixed_motive_coin_game.py
@@ -3,132 +3,78 @@
import numpy as np
from marltoolbox.envs.mixed_motive_coin_game import MixedMotiveCoinGame
+from coin_game_tests_utils import (
+ check_custom_obs,
+ assert_logger_buffer_size,
+ helper_test_reset,
+ helper_test_step,
+ init_several_envs,
+ helper_test_multiple_steps,
+ helper_test_multi_ple_episodes,
+ helper_assert_info,
+)
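These helpers come from the new shared module `coin_game_tests_utils`, which is not included in this diff. Judging from the call sites below and from the per-file helpers the refactoring removes, the shared interface is presumably close to the following sketch; the bodies are reconstructions for illustration only, not the actual implementation (the real `helper_assert_info`, for instance, also accepts `overwrite_pos_fn`, the expected pick speeds and ownership ratios, and the coop statistics used by the SSD variant).

```python
# Illustrative sketch of the assumed coin_game_tests_utils interface.
# Signatures are inferred from the call sites in this diff; bodies are
# reconstructed from the per-file helpers being deleted.


def init_several_envs(
    classes,
    max_steps,
    grid_size,
    players_can_pick_same_coin=True,
    same_obs_for_each_player=True,
    batch_size=None,
):
    """Instantiate one environment per class with a shared config dict."""
    envs = []
    for env_class in classes:
        config = {
            "max_steps": max_steps,
            "grid_size": grid_size,
            "both_players_can_pick_the_same_coin": players_can_pick_same_coin,
            "same_obs_for_each_player": same_obs_for_each_player,
        }
        if batch_size is not None:
            config["batch_size"] = batch_size
        env = env_class(config)
        env.seed(None)
        envs.append(env)
    return envs


def assert_logger_buffer_size(env, n_steps):
    """Each logging buffer holds exactly one entry per step of the episode."""
    assert len(env.red_pick) == n_steps
    assert len(env.red_pick_own) == n_steps
    assert len(env.blue_pick) == n_steps
    assert len(env.blue_pick_own) == n_steps


def helper_test_reset(envs, check_obs_fn, **check_obs_kwargs):
    """Reset every env, then validate the first obs and the empty buffers."""
    for env in envs:
        obs = env.reset()
        check_obs_fn(obs, **check_obs_kwargs)
        assert_logger_buffer_size(env, n_steps=0)
```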
# TODO add tests for grid_size != 3
-def test_reset():
- max_steps, grid_size = 20, 3
- envs = init_several_env(max_steps, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
-def init_several_env(max_steps, grid_size,
- players_can_pick_same_coin=True,
- same_obs_for_each_player=True):
- mixed_motive_coin_game = init_env(
- max_steps,
- MixedMotiveCoinGame,
- grid_size,
+def init_my_envs(
+ max_steps,
+ grid_size,
+ players_can_pick_same_coin=True,
+ same_obs_for_each_player=True,
+):
+ return init_several_envs(
+ classes=(MixedMotiveCoinGame,),
+ max_steps=max_steps,
+ grid_size=grid_size,
players_can_pick_same_coin=players_can_pick_same_coin,
- same_obs_for_each_player=same_obs_for_each_player)
- return [mixed_motive_coin_game]
-
-
-def init_env(max_steps,
- env_class,
- seed=None,
- grid_size=3,
- players_can_pick_same_coin=True,
- same_obs_for_each_player=False):
- config = {
- "max_steps": max_steps,
- "grid_size": grid_size,
- "both_players_can_pick_the_same_coin": players_can_pick_same_coin,
- "same_obs_for_each_player": same_obs_for_each_player,
- }
- env = env_class(config)
- env.seed(seed)
- return env
+ same_obs_for_each_player=same_obs_for_each_player,
+ )
def check_obs(obs, grid_size):
- assert len(obs) == 2, "two players"
- for key, player_obs in obs.items():
- assert player_obs.shape == (grid_size, grid_size, 4)
- assert player_obs[..., 0].sum() == 1.0, \
- f"observe 1 player red in grid: {player_obs[..., 0]}"
- assert player_obs[..., 1].sum() == 1.0, \
- f"observe 1 player blue in grid: {player_obs[..., 1]}"
- assert player_obs[..., 2:].sum() == 2.0, \
- f"observe 1 coin in grid: {player_obs[..., 0]}"
+ check_custom_obs(obs, grid_size, n_in_2_and_above=2.0)
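`check_custom_obs` is also provided by `coin_game_tests_utils`: `n_in_2_and_above` is the number of coins expected across the observation channels from index 2 upward (two simultaneous coins in the mixed-motive variants, one in the plain coin game), and the SSD test later in this diff additionally passes `n_layers=6` for its larger observation. A hedged reconstruction, based only on the per-file checkers this diff removes:

```python
# Assumed shape of check_custom_obs, reconstructed from the removed
# per-file check_obs helpers; the real implementation lives in
# coin_game_tests_utils and may differ in details such as how it handles
# the batched observations of the vectorized envs.
def check_custom_obs(obs, grid_size, n_in_2_and_above=1.0, n_layers=4):
    assert len(obs) == 2, "two players"
    for player_obs in obs.values():
        assert player_obs.shape == (grid_size, grid_size, n_layers)
        assert player_obs[..., 0].sum() == 1.0, "exactly one red player"
        assert player_obs[..., 1].sum() == 1.0, "exactly one blue player"
        assert (
            player_obs[..., 2:].sum() == n_in_2_and_above
        ), "expected number of coins in the remaining channels"
```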
-def assert_logger_buffer_size(env, n_steps):
- assert len(env.red_pick) == n_steps
- assert len(env.red_pick_own) == n_steps
- assert len(env.blue_pick) == n_steps
- assert len(env.blue_pick_own) == n_steps
+def test_reset():
+ max_steps, grid_size = 20, 3
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_reset(envs, check_obs, grid_size=grid_size)
def test_step():
max_steps, grid_size = 20, 3
- envs = init_several_env(max_steps, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1)
- for policy_id in env.players_ids}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=1)
- assert not done["__all__"]
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_step(envs, check_obs, grid_size=grid_size)
def test_multiple_steps():
max_steps, grid_size = 20, 3
n_steps = int(max_steps * 0.75)
- envs = init_several_env(max_steps, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- for step_i in range(1, n_steps, 1):
- actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1)
- for policy_id in env.players_ids}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"]
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_multiple_steps(
+ envs,
+ n_steps,
+ check_obs,
+ grid_size=grid_size,
+ )
def test_multiple_episodes():
max_steps, grid_size = 20, 3
n_steps = int(max_steps * 8.25)
- envs = init_several_env(max_steps, grid_size)
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_multi_ple_episodes(
+ envs,
+ n_steps,
+ max_steps,
+ check_obs,
+ grid_size=grid_size,
+ )
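For reference, the core of `helper_test_multi_ple_episodes` presumably matches the inline loop removed just below: step with random actions, allow `done["__all__"]` only at `max_steps`, and reset whenever an episode ends. A sketch under that assumption (the shared helper may add further checks, such as the logger-buffer assertions):

```python
import random


# Sketch of the assumed episode loop behind helper_test_multi_ple_episodes;
# reconstructed from the inline loop this diff removes.
def helper_test_multiple_episodes_sketch(
    envs, n_steps, max_steps, check_obs_fn, **check_obs_kwargs
):
    for env in envs:
        obs = env.reset()
        check_obs_fn(obs, **check_obs_kwargs)
        step_i = 0
        for _ in range(n_steps):
            step_i += 1
            actions = {
                policy_id: random.randint(0, env.NUM_ACTIONS - 1)
                for policy_id in env.players_ids
            }
            obs, reward, done, info = env.step(actions)
            check_obs_fn(obs, **check_obs_kwargs)
            # the episode may only terminate exactly at max_steps
            assert not done["__all__"] or (
                step_i == max_steps and done["__all__"]
            )
            if done["__all__"]:
                obs = env.reset()
                check_obs_fn(obs, **check_obs_kwargs)
                step_i = 0
```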
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
- for _ in range(n_steps):
- step_i += 1
- actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1)
- for policy_id in env.players_ids}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"] or \
- (step_i == max_steps and done["__all__"])
- if done["__all__"]:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
-
-
-def overwrite_pos(env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos):
+def overwrite_pos(env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, **kwargs):
env.red_pos = p_red_pos
env.blue_pos = p_blue_pos
env.red_coin_pos = c_red_pos
@@ -140,41 +86,6 @@ def overwrite_pos(env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos):
env.blue_coin_pos = np.array(env.blue_coin_pos)
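The `**kwargs` added to `overwrite_pos` presumably lets the shared `helper_assert_info` call the injected `overwrite_pos_fn` with keyword arguments, some of which (for example the batch-related offsets used by the vectorized tests) do not apply to this environment. A hypothetical illustration of that calling pattern; `run_scripted_step` and the `scripted` dict are made up for this example:

```python
def run_scripted_step(env, step_i, scripted, overwrite_pos_fn):
    # Extra keyword arguments passed by other helpers would simply be
    # absorbed by the **kwargs of this file's overwrite_pos.
    overwrite_pos_fn(
        env=env,
        p_red_pos=scripted["p_red_pos"][step_i],
        p_blue_pos=scripted["p_blue_pos"][step_i],
        c_red_pos=scripted["c_red_pos"][step_i],
        c_blue_pos=scripted["c_blue_pos"][step_i],
    )
```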
-def assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed, blue_speed, red_own, blue_own):
- step_i = 0
- for _ in range(n_steps):
- step_i += 1
- actions = {"player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1]}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"] or (step_i == max_steps and done["__all__"])
-
- if done["__all__"]:
- assert info["player_red"]["pick_speed"] == red_speed
- assert info["player_blue"]["pick_speed"] == blue_speed
-
- if red_own is None:
- assert "pick_own_color" not in info["player_red"]
- else:
- assert info["player_red"]["pick_own_color"] == red_own
- if blue_own is None:
- assert "pick_own_color" not in info["player_blue"]
- else:
- assert info["player_blue"]["pick_own_color"] == blue_own
-
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
-
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
-
-
def test_logged_info_no_picking():
p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
@@ -184,18 +95,34 @@ def test_logged_info_no_picking():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
+ envs = init_my_envs(max_steps, grid_size)
for env in envs:
obs = env.reset()
check_obs(obs, grid_size)
assert_logger_buffer_size(env, n_steps=0)
overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
+ )
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__red_pick_red_all_the_time():
@@ -207,18 +134,26 @@ def test_logged_info__red_pick_red_all_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__blue_pick_red_all_the_time():
@@ -230,18 +165,26 @@ def test_logged_info__blue_pick_red_all_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__blue_pick_blue_all_the_time():
@@ -253,18 +196,26 @@ def test_logged_info__blue_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__red_pick_blue_all_the_time():
@@ -276,18 +227,26 @@ def test_logged_info__red_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__both_pick_blue_all_the_time():
@@ -299,18 +258,26 @@ def test_logged_info__both_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=1.0, red_own=0.0, blue_own=1.0)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=0.0,
+ blue_own=1.0,
+ )
def test_logged_info__both_pick_red_all_the_time():
@@ -322,18 +289,26 @@ def test_logged_info__both_pick_red_all_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=1.0, red_own=1.0, blue_own=0.0)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=1.0,
+ blue_own=0.0,
+ )
def test_logged_info__both_pick_red_half_the_time():
@@ -345,18 +320,26 @@ def test_logged_info__both_pick_red_half_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__both_pick_blue_half_the_time():
@@ -368,18 +351,26 @@ def test_logged_info__both_pick_blue_half_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__both_pick_blue():
@@ -391,18 +382,26 @@ def test_logged_info__both_pick_blue():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__pick_half_the_time_half_blue_half_red():
@@ -414,193 +413,412 @@ def test_logged_info__pick_half_the_time_half_blue_half_red():
c_blue_pos = [[0, 0], [1, 1], [0, 0], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0])
-
- assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_observations_are_invariant_to_the_player_trained_in_reset():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], [0, 0], [0, 1], [0, 0], [0, 0],
- [2, 2], [0, 0], [0, 0], [0, 0], [2, 1]]
- c_blue_pos = [[0, 0], [1, 1], [0, 0], [0, 1], [2, 2],
- [0, 0], [0, 0], [0, 0], [2, 1], [0, 0]]
+ c_red_pos = [
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [0, 0],
+ [0, 0],
+ [2, 2],
+ [0, 0],
+ [0, 0],
+ [0, 0],
+ [2, 1],
+ ]
+ c_blue_pos = [
+ [0, 0],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 2],
+ [0, 0],
+ [0, 0],
+ [0, 0],
+ [2, 1],
+ [0, 0],
+ ]
max_steps, grid_size = 10, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size,
- same_obs_for_each_player=False)
+ envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=False)
for env_i, env in enumerate(envs):
obs = env.reset()
assert_obs_is_symmetrical(obs, env)
step_i = 0
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
for _ in range(n_steps):
step_i += 1
- actions = {"player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1]}
+ actions = {
+ "player_red": p_red_act[step_i - 1],
+ "player_blue": p_blue_act[step_i - 1],
+ }
_, _, _, _ = env.step(actions)
if step_i == max_steps:
break
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
def assert_obs_is_symmetrical(obs, env):
- assert np.all(obs[env.players_ids[0]][..., 0] ==
- obs[env.players_ids[1]][..., 1])
- assert np.all(obs[env.players_ids[1]][..., 0] ==
- obs[env.players_ids[0]][..., 1])
- assert np.all(obs[env.players_ids[0]][..., 2] ==
- obs[env.players_ids[1]][..., 3])
- assert np.all(obs[env.players_ids[1]][..., 2] ==
- obs[env.players_ids[0]][..., 3])
+ assert np.all(
+ obs[env.players_ids[0]][..., 0] == obs[env.players_ids[1]][..., 1]
+ )
+ assert np.all(
+ obs[env.players_ids[1]][..., 0] == obs[env.players_ids[0]][..., 1]
+ )
+ assert np.all(
+ obs[env.players_ids[0]][..., 2] == obs[env.players_ids[1]][..., 3]
+ )
+ assert np.all(
+ obs[env.players_ids[1]][..., 2] == obs[env.players_ids[0]][..., 3]
+ )
def test_observations_are_invariant_to_the_player_trained_in_step():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], [0, 0], [0, 1], [0, 0], [0, 0],
- [2, 2], [0, 0], [0, 0], [0, 0], [2, 1]]
- c_blue_pos = [[0, 0], [1, 1], [0, 0], [0, 1], [2, 2],
- [0, 0], [0, 0], [0, 0], [2, 1], [0, 0]]
+ c_red_pos = [
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [0, 0],
+ [0, 0],
+ [2, 2],
+ [0, 0],
+ [0, 0],
+ [0, 0],
+ [2, 1],
+ ]
+ c_blue_pos = [
+ [0, 0],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 2],
+ [0, 0],
+ [0, 0],
+ [0, 0],
+ [2, 1],
+ [0, 0],
+ ]
max_steps, grid_size = 10, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size,
- same_obs_for_each_player=False)
+ envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=False)
for env_i, env in enumerate(envs):
_ = env.reset()
step_i = 0
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
for _ in range(n_steps):
step_i += 1
- actions = {"player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1]}
+ actions = {
+ "player_red": p_red_act[step_i - 1],
+ "player_blue": p_blue_act[step_i - 1],
+ }
obs, reward, done, info = env.step(actions)
            # assert observations are symmetrical with respect to the two players and their actions
if step_i % 2 == 1:
obs_step_odd = obs
elif step_i % 2 == 0:
- assert np.all(obs[env.players_ids[0]] ==
- obs_step_odd[env.players_ids[1]])
- assert np.all(obs[env.players_ids[1]] ==
- obs_step_odd[env.players_ids[0]])
+ assert np.all(
+ obs[env.players_ids[0]] == obs_step_odd[env.players_ids[1]]
+ )
+ assert np.all(
+ obs[env.players_ids[1]] == obs_step_odd[env.players_ids[0]]
+ )
assert_obs_is_symmetrical(obs, env)
if step_i == max_steps:
break
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
def test_observations_are_not_invariant_to_the_player_trained_in_reset():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], [0, 0], [0, 1], [0, 0], [0, 0],
- [2, 2], [0, 0], [0, 0], [0, 0], [2, 1]]
- c_blue_pos = [[0, 0], [1, 1], [0, 0], [0, 1], [2, 2],
- [0, 0], [0, 0], [0, 0], [2, 1], [0, 0]]
+ c_red_pos = [
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [0, 0],
+ [0, 0],
+ [2, 2],
+ [0, 0],
+ [0, 0],
+ [0, 0],
+ [2, 1],
+ ]
+ c_blue_pos = [
+ [0, 0],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 2],
+ [0, 0],
+ [0, 0],
+ [0, 0],
+ [2, 1],
+ [0, 0],
+ ]
max_steps, grid_size = 10, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size,
- same_obs_for_each_player=True)
+ envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=True)
for env_i, env in enumerate(envs):
obs = env.reset()
assert_obs_is_not_symmetrical(obs, env)
step_i = 0
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
for _ in range(n_steps):
step_i += 1
- actions = {"player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1]}
+ actions = {
+ "player_red": p_red_act[step_i - 1],
+ "player_blue": p_blue_act[step_i - 1],
+ }
_, _, _, _ = env.step(actions)
if step_i == max_steps:
break
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
def assert_obs_is_not_symmetrical(obs, env):
- assert np.all(obs[env.players_ids[0]] ==
- obs[env.players_ids[1]])
+ assert np.all(obs[env.players_ids[0]] == obs[env.players_ids[1]])
def test_observations_are_not_invariant_to_the_player_trained_in_step():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], [0, 0], [0, 1], [0, 0], [0, 0],
- [2, 2], [0, 0], [0, 0], [0, 0], [2, 1]]
- c_blue_pos = [[0, 0], [1, 1], [0, 0], [0, 1], [2, 2],
- [0, 0], [0, 0], [0, 0], [2, 1], [0, 0]]
+ c_red_pos = [
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [0, 0],
+ [0, 0],
+ [2, 2],
+ [0, 0],
+ [0, 0],
+ [0, 0],
+ [2, 1],
+ ]
+ c_blue_pos = [
+ [0, 0],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 2],
+ [0, 0],
+ [0, 0],
+ [0, 0],
+ [2, 1],
+ [0, 0],
+ ]
max_steps, grid_size = 10, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size,
- same_obs_for_each_player=True)
+ envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=True)
for env_i, env in enumerate(envs):
_ = env.reset()
step_i = 0
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
for _ in range(n_steps):
step_i += 1
- actions = {"player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1]}
+ actions = {
+ "player_red": p_red_act[step_i - 1],
+ "player_blue": p_blue_act[step_i - 1],
+ }
obs, reward, done, info = env.step(actions)
            # assert observations are NOT symmetrical with respect to the two players and their actions
if step_i % 2 == 1:
obs_step_odd = obs
elif step_i % 2 == 0:
- assert np.any(obs[env.players_ids[0]] !=
- obs_step_odd[env.players_ids[1]])
- assert np.any(obs[env.players_ids[1]] !=
- obs_step_odd[env.players_ids[0]])
+ assert np.any(
+ obs[env.players_ids[0]] != obs_step_odd[env.players_ids[1]]
+ )
+ assert np.any(
+ obs[env.players_ids[1]] != obs_step_odd[env.players_ids[0]]
+ )
assert_obs_is_not_symmetrical(obs, env)
if step_i == max_steps:
break
- overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i],
- c_red_pos[step_i], c_blue_pos[step_i])
+ overwrite_pos(
+ env,
+ p_red_pos[step_i],
+ p_blue_pos[step_i],
+ c_red_pos[step_i],
+ c_blue_pos[step_i],
+ )
diff --git a/tests/marltoolbox/envs/test_ssd_mixed_motive_coin_game.py b/tests/marltoolbox/envs/test_ssd_mixed_motive_coin_game.py
index 6d0c04f..63697b7 100644
--- a/tests/marltoolbox/envs/test_ssd_mixed_motive_coin_game.py
+++ b/tests/marltoolbox/envs/test_ssd_mixed_motive_coin_game.py
@@ -3,152 +3,86 @@
import numpy as np
from marltoolbox.envs.ssd_mixed_motive_coin_game import SSDMixedMotiveCoinGame
+from coin_game_tests_utils import (
+ check_custom_obs,
+ assert_logger_buffer_size,
+ helper_test_reset,
+ helper_test_step,
+ init_several_envs,
+ helper_test_multiple_steps,
+ helper_test_multi_ple_episodes,
+ helper_assert_info,
+)
# TODO add tests for grid_size != 3
-def test_reset():
- max_steps, grid_size = 20, 3
- envs = init_several_env(max_steps, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
-
-def init_several_env(
+def init_my_envs(
max_steps,
grid_size,
players_can_pick_same_coin=True,
same_obs_for_each_player=True,
):
- mixed_motive_coin_game = init_env(
- max_steps,
- SSDMixedMotiveCoinGame,
- grid_size,
+ return init_several_envs(
+ (SSDMixedMotiveCoinGame,),
+ max_steps=max_steps,
+ grid_size=grid_size,
players_can_pick_same_coin=players_can_pick_same_coin,
same_obs_for_each_player=same_obs_for_each_player,
)
- return [mixed_motive_coin_game]
-
-
-def init_env(
- max_steps,
- env_class,
- seed=None,
- grid_size=3,
- players_can_pick_same_coin=True,
- same_obs_for_each_player=False,
-):
- config = {
- "max_steps": max_steps,
- "grid_size": grid_size,
- "both_players_can_pick_the_same_coin": players_can_pick_same_coin,
- "same_obs_for_each_player": same_obs_for_each_player,
- }
- env = env_class(config)
- env.seed(seed)
- return env
def check_obs(obs, grid_size):
- assert len(obs) == 2, "two players"
- for key, player_obs in obs.items():
- assert player_obs.shape == (grid_size, grid_size, 6)
- assert (
- player_obs[..., 0].sum() == 1.0
- ), f"observe 1 player red in grid: {player_obs[..., 0]}"
- assert (
- player_obs[..., 1].sum() == 1.0
- ), f"observe 1 player blue in grid: {player_obs[..., 1]}"
- assert (
- player_obs[..., 2:].sum() == 2.0
- ), f"observe 1 coin in grid: {player_obs[..., 0]}"
-
-
-def assert_logger_buffer_size(env, n_steps):
- assert len(env.red_pick) == n_steps
- assert len(env.red_pick_own) == n_steps
- assert len(env.blue_pick) == n_steps
- assert len(env.blue_pick_own) == n_steps
+ check_custom_obs(obs, grid_size, n_in_2_and_above=2.0, n_layers=6)
-def test_step():
+def test_reset():
max_steps, grid_size = 20, 3
- envs = init_several_env(max_steps, grid_size)
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_reset(envs, check_obs, grid_size=grid_size)
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- actions = {
- policy_id: random.randint(0, env.NUM_ACTIONS - 1)
- for policy_id in env.players_ids
- }
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=1)
- assert not done["__all__"]
+def test_step():
+ max_steps, grid_size = 20, 3
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_step(envs, check_obs, grid_size=grid_size)
def test_multiple_steps():
max_steps, grid_size = 20, 3
n_steps = int(max_steps * 0.75)
- envs = init_several_env(max_steps, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- for step_i in range(1, n_steps, 1):
- actions = {
- policy_id: random.randint(0, env.NUM_ACTIONS - 1)
- for policy_id in env.players_ids
- }
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"]
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_multiple_steps(
+ envs,
+ n_steps,
+ check_obs,
+ grid_size=grid_size,
+ )
def test_multiple_episodes():
max_steps, grid_size = 20, 3
n_steps = int(max_steps * 8.25)
- envs = init_several_env(max_steps, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- step_i = 0
- for _ in range(n_steps):
- step_i += 1
- actions = {
- policy_id: random.randint(0, env.NUM_ACTIONS - 1)
- for policy_id in env.players_ids
- }
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"] or (
- step_i == max_steps and done["__all__"]
- )
- if done["__all__"]:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
+ envs = init_my_envs(max_steps, grid_size)
+ helper_test_multi_ple_episodes(
+ envs,
+ n_steps,
+ max_steps,
+ check_obs,
+ grid_size=grid_size,
+ )
def overwrite_pos(
- env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, c_red_coin=True
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ **kwargs,
):
- env.red_coin = c_red_coin
+ env.red_coin = True
env.red_pos = p_red_pos
env.blue_pos = p_blue_pos
env.red_coin_pos = c_red_pos
@@ -160,75 +94,6 @@ def overwrite_pos(
env.blue_coin_pos = np.array(env.blue_coin_pos)
-def assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed,
- blue_speed,
- red_own,
- blue_own,
- blue_coop_fraction=None,
- red_coop_fraction=None,
- red_coop_speed=None,
- blue_coop_speed=None,
-):
- step_i = 0
- for _ in range(n_steps):
- step_i += 1
- actions = {
- "player_red": p_red_act[step_i - 1],
- "player_blue": p_blue_act[step_i - 1],
- }
- obs, reward, done, info = env.step(actions)
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"] or (step_i == max_steps and done["__all__"])
-
- if done["__all__"]:
- assert info["player_red"]["pick_speed"] == red_speed
- assert info["player_blue"]["pick_speed"] == blue_speed
-
- assert_not_present_in_dict_or_equal(
- "pick_own_color", red_own, info, "player_red"
- )
- assert_not_present_in_dict_or_equal(
- "pick_own_color", blue_own, info, "player_blue"
- )
- assert_not_present_in_dict_or_equal(
- "blue_coop_fraction", blue_coop_fraction, info, "player_blue"
- )
- assert_not_present_in_dict_or_equal(
- "red_coop_fraction", red_coop_fraction, info, "player_red"
- )
- assert_not_present_in_dict_or_equal(
- "red_coop_speed", red_coop_speed, info, "player_red"
- )
- assert_not_present_in_dict_or_equal(
- "blue_coop_speed", blue_coop_speed, info, "player_blue"
- )
-
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
-
- overwrite_pos(
- env,
- p_red_pos[step_i],
- p_blue_pos[step_i],
- c_red_pos[step_i],
- c_blue_pos[step_i],
- )
-
-
def assert_not_present_in_dict_or_equal(key, value, info, player):
if value is None:
assert key not in info[player]
@@ -245,36 +110,30 @@ def test_logged_info_no_picking():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- red_coop_speed=0.0,
- blue_coop_speed=0.0,
- red_coop_fraction=None,
- blue_coop_fraction=None,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ red_coop_speed=0.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=None,
+ blue_coop_fraction=None,
+ )
def test_logged_info__red_pick_red_all_the_time():
@@ -286,36 +145,30 @@ def test_logged_info__red_pick_red_all_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- red_coop_speed=0.0,
- blue_coop_speed=0.0,
- red_coop_fraction=None,
- blue_coop_fraction=None,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ red_coop_speed=0.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=None,
+ blue_coop_fraction=None,
+ )
def test_logged_info__blue_pick_red_all_the_time():
@@ -327,36 +180,30 @@ def test_logged_info__blue_pick_red_all_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- red_coop_speed=0.0,
- blue_coop_speed=0.0,
- red_coop_fraction=None,
- blue_coop_fraction=None,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ red_coop_speed=0.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=None,
+ blue_coop_fraction=None,
+ )
def test_logged_info__blue_pick_blue_all_the_time():
@@ -368,36 +215,30 @@ def test_logged_info__blue_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=1.0,
- red_own=None,
- blue_own=1.0,
- red_coop_speed=0.0,
- blue_coop_speed=0.0,
- red_coop_fraction=None,
- blue_coop_fraction=0.0,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=1.0,
+ red_coop_speed=0.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=None,
+ blue_coop_fraction=0.0,
+ )
def test_logged_info__red_pick_blue_all_the_time():
@@ -409,36 +250,30 @@ def test_logged_info__red_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- red_coop_speed=0.0,
- blue_coop_speed=0.0,
- red_coop_fraction=None,
- blue_coop_fraction=None,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ red_coop_speed=0.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=None,
+ blue_coop_fraction=None,
+ )
def test_logged_info__both_pick_blue_all_the_time():
@@ -450,36 +285,30 @@ def test_logged_info__both_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=1.0,
- red_own=None,
- blue_own=1.0,
- red_coop_speed=0.0,
- blue_coop_speed=0.0,
- red_coop_fraction=None,
- blue_coop_fraction=0.0,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=1.0,
+ red_coop_speed=0.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=None,
+ blue_coop_fraction=0.0,
+ )
def test_logged_info__both_pick_red_all_the_time():
@@ -491,36 +320,30 @@ def test_logged_info__both_pick_red_all_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=1.0,
- blue_speed=1.0,
- red_own=1.0,
- blue_own=0.0,
- red_coop_speed=1.0,
- blue_coop_speed=0.0,
- red_coop_fraction=1.0,
- blue_coop_fraction=1.0,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=1.0,
+ blue_own=0.0,
+ red_coop_speed=1.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=1.0,
+ blue_coop_fraction=1.0,
+ )
def test_logged_info__both_pick_red_half_the_time():
@@ -532,36 +355,30 @@ def test_logged_info__both_pick_red_half_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- red_coop_speed=0.0,
- blue_coop_speed=0.0,
- red_coop_fraction=None,
- blue_coop_fraction=None,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ red_coop_speed=0.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=None,
+ blue_coop_fraction=None,
+ )
def test_logged_info__both_pick_blue_half_the_time():
@@ -573,36 +390,30 @@ def test_logged_info__both_pick_blue_half_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.5,
- red_own=None,
- blue_own=1.0,
- red_coop_speed=0.0,
- blue_coop_speed=0.0,
- red_coop_fraction=None,
- blue_coop_fraction=0.0,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.5,
+ red_own=None,
+ blue_own=1.0,
+ red_coop_speed=0.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=None,
+ blue_coop_fraction=0.0,
+ )
def test_logged_info__both_pick_blue():
@@ -614,36 +425,30 @@ def test_logged_info__both_pick_blue():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.5,
- red_own=None,
- blue_own=1.0,
- red_coop_speed=0.0,
- blue_coop_speed=0.0,
- red_coop_fraction=None,
- blue_coop_fraction=0.0,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.5,
+ red_own=None,
+ blue_own=1.0,
+ red_coop_speed=0.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=None,
+ blue_coop_fraction=0.0,
+ )
def test_logged_info__pick_half_the_time_half_blue_half_red():
@@ -655,36 +460,30 @@ def test_logged_info__pick_half_the_time_half_blue_half_red():
c_blue_pos = [[0, 0], [1, 1], [0, 0], [1, 1]]
max_steps, grid_size = 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, grid_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- overwrite_pos(
- env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]
- )
-
- assert_info(
- n_steps,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.25,
- red_own=None,
- blue_own=1.0,
- red_coop_speed=0.0,
- blue_coop_speed=0.0,
- red_coop_fraction=None,
- blue_coop_fraction=0.0,
- )
+ envs = init_my_envs(max_steps, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.25,
+ red_own=None,
+ blue_own=1.0,
+ red_coop_speed=0.0,
+ blue_coop_speed=0.0,
+ red_coop_fraction=None,
+ blue_coop_fraction=0.0,
+ )
def test_observations_are_not_invariant_to_the_player_trained_in_reset():
@@ -740,9 +539,7 @@ def test_observations_are_not_invariant_to_the_player_trained_in_reset():
]
max_steps, grid_size = 10, 3
n_steps = max_steps
- envs = init_several_env(
- max_steps, grid_size, same_obs_for_each_player=True
- )
+ envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=True)
for env_i, env in enumerate(envs):
obs = env.reset()
@@ -833,9 +630,7 @@ def test_observations_are_not_invariant_to_the_player_trained_in_step():
]
max_steps, grid_size = 10, 3
n_steps = max_steps
- envs = init_several_env(
- max_steps, grid_size, same_obs_for_each_player=True
- )
+ envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=True)
for env_i, env in enumerate(envs):
_ = env.reset()
diff --git a/tests/marltoolbox/envs/test_vectorized_coin_game.py b/tests/marltoolbox/envs/test_vectorized_coin_game.py
index 011ab7c..080e42c 100644
--- a/tests/marltoolbox/envs/test_vectorized_coin_game.py
+++ b/tests/marltoolbox/envs/test_vectorized_coin_game.py
@@ -4,160 +4,128 @@
import numpy as np
from flaky import flaky
-from marltoolbox.envs.vectorized_coin_game import VectorizedCoinGame, \
- AsymVectorizedCoinGame
-from test_coin_game import \
- assert_obs_is_symmetrical, assert_obs_is_not_symmetrical
+from coin_game_tests_utils import (
+ check_custom_obs,
+ helper_test_reset,
+ helper_test_step,
+ init_several_envs,
+ helper_test_multiple_steps,
+ helper_test_multi_ple_episodes,
+ helper_assert_info,
+ shift_consistently,
+)
+from marltoolbox.envs.vectorized_coin_game import (
+ VectorizedCoinGame,
+ AsymVectorizedCoinGame,
+)
+from test_coin_game import (
+ assert_obs_is_symmetrical,
+ assert_obs_is_not_symmetrical,
+)
# TODO add tests for grid_size != 3
-def test_reset():
- max_steps, batch_size, grid_size = 20, 5, 3
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
-
-def init_several_env(max_steps, batch_size, grid_size,
- players_can_pick_same_coin=True,
- same_obs_for_each_player=True):
- coin_game = init_env(max_steps, batch_size, VectorizedCoinGame, grid_size,
- players_can_pick_same_coin=players_can_pick_same_coin,
- same_obs_for_each_player=same_obs_for_each_player)
- asymm_coin_game = \
- init_env(max_steps, batch_size, AsymVectorizedCoinGame, grid_size,
- players_can_pick_same_coin=players_can_pick_same_coin,
- same_obs_for_each_player=same_obs_for_each_player)
- return [coin_game, asymm_coin_game]
-
-
-def init_env(max_steps, batch_size, env_class, seed=None, grid_size=3,
- players_can_pick_same_coin=True,
- same_obs_for_each_player=False):
- config = {
- "max_steps": max_steps,
- "batch_size": batch_size,
- "grid_size": grid_size,
- "same_obs_for_each_player": same_obs_for_each_player,
- "both_players_can_pick_the_same_coin": players_can_pick_same_coin,
- }
- env = env_class(config)
- env.seed(seed)
- return env
+def init_my_envs(
+ max_steps,
+ batch_size,
+ grid_size,
+ players_can_pick_same_coin=True,
+ same_obs_for_each_player=True,
+):
+ return init_several_envs(
+ classes=(VectorizedCoinGame, AsymVectorizedCoinGame),
+ max_steps=max_steps,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ players_can_pick_same_coin=players_can_pick_same_coin,
+ same_obs_for_each_player=same_obs_for_each_player,
+ )
def check_obs(obs, batch_size, grid_size):
- assert len(obs) == 2, "two players"
- for i in range(batch_size):
- for key, player_obs in obs.items():
- assert player_obs.shape == (batch_size, grid_size, grid_size, 4)
- assert player_obs[i, ..., 0].sum() == 1.0, \
- f"observe 1 player red in grid: {player_obs[i, ..., 0]}"
- assert player_obs[i, ..., 1].sum() == 1.0, \
- f"observe 1 player blue in grid: {player_obs[i, ..., 1]}"
- assert player_obs[i, ..., 2:].sum() == 1.0, \
- f"observe 1 coin in grid: {player_obs[i, ..., 0]}"
+ check_custom_obs(obs, grid_size, batch_size=batch_size)
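In the vectorized tests the checker receives a `batch_size` keyword, so every slot of the batched observation has to satisfy the same invariants as the single-env case. The sketch below mirrors the per-file checker removed above; `check_batched_obs` is an illustrative name, and the shared `check_custom_obs` presumably folds this batched path in behind its `batch_size` argument.

```python
def check_batched_obs(obs, batch_size, grid_size, n_layers=4):
    """Same invariants as the single-env check, applied to each batch slot."""
    assert len(obs) == 2, "two players"
    for player_obs in obs.values():
        assert player_obs.shape == (batch_size, grid_size, grid_size, n_layers)
        for i in range(batch_size):
            assert player_obs[i, ..., 0].sum() == 1.0, "one red player per slot"
            assert player_obs[i, ..., 1].sum() == 1.0, "one blue player per slot"
            assert player_obs[i, ..., 2:].sum() == 1.0, "one coin per slot"
```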
-def assert_logger_buffer_size(env, n_steps):
- assert len(env.red_pick) == n_steps
- assert len(env.red_pick_own) == n_steps
- assert len(env.blue_pick) == n_steps
- assert len(env.blue_pick_own) == n_steps
+def test_reset():
+ max_steps, batch_size, grid_size = 20, 5, 3
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_reset(
+ envs, check_obs, grid_size=grid_size, batch_size=batch_size
+ )
def test_step():
max_steps, batch_size, grid_size = 20, 5, 3
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- actions = {policy_id: [random.randint(0, env.NUM_ACTIONS - 1)
- for _ in range(batch_size)]
- for policy_id in env.players_ids}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=1)
- assert not done["__all__"]
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_step(
+ envs,
+ check_obs,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ )
def test_multiple_steps():
max_steps, batch_size, grid_size = 20, 5, 3
n_steps = int(max_steps * 0.75)
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- for step_i in range(1, n_steps, 1):
- actions = {
- policy_id: [random.randint(0, env.NUM_ACTIONS - 1)
- for _ in range(batch_size)]
- for policy_id in env.players_ids}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"]
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_multiple_steps(
+ envs,
+ n_steps,
+ check_obs,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ )
def test_multiple_episodes():
max_steps, batch_size, grid_size = 20, 100, 3
n_steps = int(max_steps * 8.25)
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- step_i = 0
- for _ in range(n_steps):
- step_i += 1
- actions = {
- policy_id: [random.randint(0, env.NUM_ACTIONS - 1)
- for _ in range(batch_size)]
- for policy_id in env.players_ids}
- obs, reward, done, info = env.step(actions)
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"] or (
- step_i == max_steps and done["__all__"])
- if done["__all__"]:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
-
-
-def overwrite_pos(step_i, batch_deltas, n_steps_in_epi, env, p_red_pos,
- p_blue_pos, c_red_pos, c_blue_pos):
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_multi_ple_episodes(
+ envs,
+ n_steps,
+ max_steps,
+ check_obs,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ )
+
+
+def overwrite_pos(
+ step_i,
+ batch_deltas,
+ n_steps_in_epi,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ **kwargs,
+):
assert len(p_red_pos) == n_steps_in_epi
assert len(p_blue_pos) == n_steps_in_epi
assert len(c_red_pos) == n_steps_in_epi
assert len(c_blue_pos) == n_steps_in_epi
- env.red_coin = [0
- if c_red_pos[(step_i + delta) % n_steps_in_epi] is None
- else 1
- for delta in batch_deltas]
- coin_pos = [c_blue_pos[(step_i + delta) % n_steps_in_epi]
- if c_red_pos[(step_i + delta) % n_steps_in_epi] is None
- else c_red_pos[(step_i + delta) % n_steps_in_epi]
- for delta in batch_deltas]
-
- env.red_pos = [p_red_pos[(step_i + delta) % n_steps_in_epi] for delta in
- batch_deltas]
- env.blue_pos = [p_blue_pos[(step_i + delta) % n_steps_in_epi] for delta in
- batch_deltas]
+ env.red_coin = [
+ 0 if c_red_pos[(step_i + delta) % n_steps_in_epi] is None else 1
+ for delta in batch_deltas
+ ]
+ coin_pos = [
+ c_blue_pos[(step_i + delta) % n_steps_in_epi]
+ if c_red_pos[(step_i + delta) % n_steps_in_epi] is None
+ else c_red_pos[(step_i + delta) % n_steps_in_epi]
+ for delta in batch_deltas
+ ]
+
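+    # shift_consistently replaces the inline (step_i + delta) % n_steps_in_epi indexing previously used for the player positions.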
+ env.red_pos = shift_consistently(
+ p_red_pos, step_i, n_steps_in_epi, batch_deltas
+ )
+ env.blue_pos = shift_consistently(
+ p_blue_pos, step_i, n_steps_in_epi, batch_deltas
+ )
env.coin_pos = coin_pos
env.red_pos = np.array(env.red_pos)
@@ -166,52 +134,6 @@ def overwrite_pos(step_i, batch_deltas, n_steps_in_epi, env, p_red_pos,
env.red_coin = np.array(env.red_coin)
-def assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, env,
- grid_size, n_steps_in_epi,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed, blue_speed, red_own, blue_own):
- step_i = 0
- delta_err = 0.01
-
- for _ in range(n_steps):
- overwrite_pos(step_i, batch_deltas, n_steps_in_epi, env,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos)
- actions = {"player_red": [p_red_act[(step_i + delta) % n_steps_in_epi]
- for delta in batch_deltas],
- "player_blue": [
- p_blue_act[(step_i + delta) % n_steps_in_epi]
- for delta in batch_deltas]}
- step_i += 1
-
- obs, reward, done, info = env.step(actions)
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"] or (
- step_i == n_steps_in_epi and done["__all__"])
-
- if done["__all__"]:
- assert abs(info["player_red"]["pick_speed"] - red_speed) \
- < delta_err
- assert abs(info["player_blue"]["pick_speed"] - blue_speed) \
- < delta_err
-
- if red_own is None:
- assert "pick_own_color" not in info["player_red"]
- else:
- assert abs(info["player_red"]["pick_own_color"] - red_own) \
- < delta_err
- if blue_own is None:
- assert "pick_own_color" not in info["player_blue"]
- else:
- assert abs(info["player_blue"]["pick_own_color"] - blue_own) \
- < delta_err
-
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
-
-
def test_logged_info_no_picking():
p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
@@ -222,30 +144,51 @@ def test_logged_info_no_picking():
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
-
- envs = init_several_env(max_steps, batch_size, grid_size,
- players_can_pick_same_coin=False)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
+
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, players_can_pick_same_coin=False
+ )
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__red_pick_red_all_the_time():
@@ -258,30 +201,51 @@ def test_logged_info__red_pick_red_all_the_time():
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=0.0, red_own=1.0, blue_own=None)
-
- envs = init_several_env(max_steps, batch_size, grid_size,
- players_can_pick_same_coin=False)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=0.0, red_own=1.0, blue_own=None)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=1.0,
+ blue_own=None,
+ )
+
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, players_can_pick_same_coin=False
+ )
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=1.0,
+ blue_own=None,
+ )
def test_logged_info__blue_pick_red_all_the_time():
@@ -294,30 +258,51 @@ def test_logged_info__blue_pick_red_all_the_time():
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=0.0)
-
- envs = init_several_env(max_steps, batch_size, grid_size,
- players_can_pick_same_coin=False)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=0.0)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=0.0,
+ )
+
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, players_can_pick_same_coin=False
+ )
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=0.0,
+ )
def test_logged_info__blue_pick_blue_all_the_time():
@@ -330,30 +315,51 @@ def test_logged_info__blue_pick_blue_all_the_time():
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=1.0)
-
- envs = init_several_env(max_steps, batch_size, grid_size,
- players_can_pick_same_coin=False)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=1.0)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=1.0,
+ )
+
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, players_can_pick_same_coin=False
+ )
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=1.0,
+ )
def test_logged_info__red_pick_blue_all_the_time():
@@ -366,30 +372,51 @@ def test_logged_info__red_pick_blue_all_the_time():
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None)
-
- envs = init_several_env(max_steps, batch_size, grid_size,
- players_can_pick_same_coin=False)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=0.0,
+ blue_own=None,
+ )
+
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, players_can_pick_same_coin=False
+ )
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=0.0,
+ blue_own=None,
+ )
def test_logged_info__red_pick_blue_all_the_time_wt_difference_in_actions():
@@ -402,30 +429,51 @@ def test_logged_info__red_pick_blue_all_the_time_wt_difference_in_actions():
max_steps, batch_size, grid_size = 4, 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None)
-
- envs = init_several_env(max_steps, batch_size, grid_size,
- players_can_pick_same_coin=False)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=0.0,
+ blue_own=None,
+ )
+
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, players_can_pick_same_coin=False
+ )
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=0.0,
+ blue_own=None,
+ )
def test_logged_info__both_pick_blue_all_the_time():
@@ -438,17 +486,27 @@ def test_logged_info__both_pick_blue_all_the_time():
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=1.0, red_own=0.0, blue_own=1.0)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=0.0,
+ blue_own=1.0,
+ )
def test_logged_info__both_pick_red_all_the_time():
@@ -460,19 +518,27 @@ def test_logged_info__both_pick_red_all_the_time():
c_blue_pos = [None, None, None, None]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=1.0, red_own=1.0, blue_own=0.0)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=1.0,
+ blue_own=0.0,
+ )
def test_logged_info__both_pick_red_half_the_time():
@@ -484,19 +550,27 @@ def test_logged_info__both_pick_red_half_the_time():
c_blue_pos = [None, None, None, None]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.5, blue_speed=0.5, red_own=1.0, blue_own=0.0)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.5,
+ blue_speed=0.5,
+ red_own=1.0,
+ blue_own=0.0,
+ )
def test_logged_info__both_pick_blue_half_the_time():
@@ -508,19 +582,27 @@ def test_logged_info__both_pick_blue_half_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.5, blue_speed=0.5, red_own=0.0, blue_own=1.0)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.5,
+ blue_speed=0.5,
+ red_own=0.0,
+ blue_own=1.0,
+ )
def test_logged_info__both_pick_blue():
@@ -532,19 +614,27 @@ def test_logged_info__both_pick_blue():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.25, blue_speed=0.5, red_own=0.0, blue_own=1.0)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.25,
+ blue_speed=0.5,
+ red_own=0.0,
+ blue_own=1.0,
+ )
def test_logged_info__pick_half_the_time_half_blue_half_red():
@@ -556,25 +646,33 @@ def test_logged_info__pick_half_the_time_half_blue_half_red():
c_blue_pos = [None, [1, 1], None, [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.5, blue_speed=0.5, red_own=0.5, blue_own=0.5)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.5,
+ blue_speed=0.5,
+ red_own=0.5,
+ blue_own=0.5,
+ )
def test_get_and_set_env_state():
max_steps, batch_size, grid_size = 20, 100, 3
n_steps = int(max_steps * 8.25)
- envs = init_several_env(max_steps, batch_size, grid_size)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
for env in envs:
obs = env.reset()
@@ -585,15 +683,23 @@ def test_get_and_set_env_state():
step_i = 0
for _ in range(n_steps):
step_i += 1
- actions = {policy_id: [random.randint(0, env.NUM_ACTIONS - 1)
- for _ in range(batch_size)]
- for policy_id in env.players_ids}
+ actions = {
+ policy_id: [
+ random.randint(0, env.NUM_ACTIONS - 1)
+ for _ in range(batch_size)
+ ]
+ for policy_id in env.players_ids
+ }
obs, reward, done, info = env.step(actions)
- assert all([v == initial_env_state_saved[k]
- if not isinstance(v, np.ndarray)
- else (v == initial_env_state_saved[k]).all()
- for k, v in initial_env_state.items()])
+ assert all(
+ [
+ v == initial_env_state_saved[k]
+ if not isinstance(v, np.ndarray)
+ else (v == initial_env_state_saved[k]).all()
+ for k, v in initial_env_state.items()
+ ]
+ )
env_state_after_step = env._save_env()
env_after_step = copy.deepcopy(env)
@@ -601,19 +707,27 @@ def test_get_and_set_env_state():
env_vars, env_initial_vars = vars(env), vars(env_initial)
env_vars.pop("np_random", None)
env_initial_vars.pop("np_random", None)
- assert all([v == env_initial_vars[k]
- if not isinstance(v, np.ndarray)
- else (v == env_initial_vars[k]).all()
- for k, v in env_vars.items()])
+ assert all(
+ [
+ v == env_initial_vars[k]
+ if not isinstance(v, np.ndarray)
+ else (v == env_initial_vars[k]).all()
+ for k, v in env_vars.items()
+ ]
+ )
env._load_env(env_state_after_step)
env_vars, env_after_step_vars = vars(env), vars(env_after_step)
env_vars.pop("np_random", None)
env_after_step_vars.pop("np_random", None)
- assert all([v == env_after_step_vars[k]
- if not isinstance(v, np.ndarray)
- else (v == env_after_step_vars[k]).all()
- for k, v in env_vars.items()])
+ assert all(
+ [
+ v == env_after_step_vars[k]
+ if not isinstance(v, np.ndarray)
+ else (v == env_after_step_vars[k]).all()
+ for k, v in env_vars.items()
+ ]
+ )
if done["__all__"]:
obs = env.reset()
@@ -621,36 +735,92 @@ def test_get_and_set_env_state():
def test_observations_are_invariant_to_the_player_trained_wt_step():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], None, [0, 1], None, None,
- [2, 2], [0, 0], None, None, [2, 1]]
- c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2],
- None, None, [0, 0], [2, 1], None]
+ c_red_pos = [
+ [1, 1],
+ None,
+ [0, 1],
+ None,
+ None,
+ [2, 2],
+ [0, 0],
+ None,
+ None,
+ [2, 1],
+ ]
+ c_blue_pos = [
+ None,
+ [1, 1],
+ None,
+ [0, 1],
+ [2, 2],
+ None,
+ None,
+ [0, 0],
+ [2, 1],
+ None,
+ ]
max_steps, batch_size, grid_size = 10, 52, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size,
- same_obs_for_each_player=False)
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, same_obs_for_each_player=False
+ )
- batch_deltas = [i % max_steps if i % 2 == 0 else i % max_steps - 1
- for i in range(batch_size)]
+ batch_deltas = [
+ i % max_steps if i % 2 == 0 else i % max_steps - 1
+ for i in range(batch_size)
+ ]
for env_i, env in enumerate(envs):
_ = env.reset()
step_i = 0
for _ in range(n_steps):
- overwrite_pos(step_i, batch_deltas, max_steps, env, p_red_pos,
- p_blue_pos,
- c_red_pos, c_blue_pos)
- actions = {"player_red": [p_red_act[(step_i + delta) % max_steps]
- for delta in batch_deltas],
- "player_blue": [p_blue_act[(step_i + delta) % max_steps]
- for delta in batch_deltas]}
+ overwrite_pos(
+ step_i,
+ batch_deltas,
+ max_steps,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ )
+ actions = {
+ "player_red": [
+ p_red_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ "player_blue": [
+ p_blue_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ }
obs, reward, done, info = env.step(actions)
step_i += 1
@@ -659,11 +829,11 @@ def test_observations_are_invariant_to_the_player_trained_wt_step():
obs_step_odd = obs
elif step_i % 2 == 0:
assert np.all(
- obs[env.players_ids[0]] == obs_step_odd[
- env.players_ids[1]])
+ obs[env.players_ids[0]] == obs_step_odd[env.players_ids[1]]
+ )
assert np.all(
- obs[env.players_ids[1]] == obs_step_odd[
- env.players_ids[0]])
+ obs[env.players_ids[1]] == obs_step_odd[env.players_ids[0]]
+ )
assert_obs_is_symmetrical(obs, env)
if step_i == max_steps:
@@ -671,23 +841,66 @@ def test_observations_are_invariant_to_the_player_trained_wt_step():
def test_observations_are_invariant_to_the_player_trained_wt_reset():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], None, [0, 1], None, None,
- [2, 2], [0, 0], None, None, [2, 1]]
- c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2],
- None, None, [0, 0], [2, 1], None]
+ c_red_pos = [
+ [1, 1],
+ None,
+ [0, 1],
+ None,
+ None,
+ [2, 2],
+ [0, 0],
+ None,
+ None,
+ [2, 1],
+ ]
+ c_blue_pos = [
+ None,
+ [1, 1],
+ None,
+ [0, 1],
+ [2, 2],
+ None,
+ None,
+ [0, 0],
+ [2, 1],
+ None,
+ ]
max_steps, batch_size, grid_size = 10, 52, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size,
- same_obs_for_each_player=False)
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, same_obs_for_each_player=False
+ )
- batch_deltas = [i % max_steps if i % 2 == 0 else i % max_steps - 1
- for i in range(batch_size)]
+ batch_deltas = [
+ i % max_steps if i % 2 == 0 else i % max_steps - 1
+ for i in range(batch_size)
+ ]
for env_i, env in enumerate(envs):
obs = env.reset()
@@ -695,12 +908,26 @@ def test_observations_are_invariant_to_the_player_trained_wt_reset():
step_i = 0
for _ in range(n_steps):
- overwrite_pos(step_i, batch_deltas, max_steps, env, p_red_pos,
- p_blue_pos, c_red_pos, c_blue_pos)
- actions = {"player_red": [p_red_act[(step_i + delta) % max_steps]
- for delta in batch_deltas],
- "player_blue": [p_blue_act[(step_i + delta) % max_steps]
- for delta in batch_deltas]}
+ overwrite_pos(
+ step_i,
+ batch_deltas,
+ max_steps,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ )
+ actions = {
+ "player_red": [
+ p_red_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ "player_blue": [
+ p_blue_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ }
_, _, _, _ = env.step(actions)
step_i += 1
@@ -710,36 +937,92 @@ def test_observations_are_invariant_to_the_player_trained_wt_reset():
def test_observations_are_not_invariant_to_the_player_trained_wt_step():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], None, [0, 1], None, None,
- [2, 2], [0, 0], None, None, [2, 1]]
- c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2],
- None, None, [0, 0], [2, 1], None]
+ c_red_pos = [
+ [1, 1],
+ None,
+ [0, 1],
+ None,
+ None,
+ [2, 2],
+ [0, 0],
+ None,
+ None,
+ [2, 1],
+ ]
+ c_blue_pos = [
+ None,
+ [1, 1],
+ None,
+ [0, 1],
+ [2, 2],
+ None,
+ None,
+ [0, 0],
+ [2, 1],
+ None,
+ ]
max_steps, batch_size, grid_size = 10, 52, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size,
- same_obs_for_each_player=True)
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, same_obs_for_each_player=True
+ )
- batch_deltas = [i % max_steps if i % 2 == 0 else i % max_steps - 1
- for i in range(batch_size)]
+ batch_deltas = [
+ i % max_steps if i % 2 == 0 else i % max_steps - 1
+ for i in range(batch_size)
+ ]
for env_i, env in enumerate(envs):
_ = env.reset()
step_i = 0
for _ in range(n_steps):
- overwrite_pos(step_i, batch_deltas, max_steps, env, p_red_pos,
- p_blue_pos,
- c_red_pos, c_blue_pos)
- actions = {"player_red": [p_red_act[(step_i + delta) % max_steps]
- for delta in batch_deltas],
- "player_blue": [p_blue_act[(step_i + delta) % max_steps]
- for delta in batch_deltas]}
+ overwrite_pos(
+ step_i,
+ batch_deltas,
+ max_steps,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ )
+ actions = {
+ "player_red": [
+ p_red_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ "player_blue": [
+ p_blue_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ }
obs, reward, done, info = env.step(actions)
step_i += 1
@@ -750,11 +1033,11 @@ def test_observations_are_not_invariant_to_the_player_trained_wt_step():
obs_step_odd = obs
elif step_i % 2 == 0:
assert np.any(
- obs[env.players_ids[0]] != obs_step_odd[env.players_ids[
- 1]])
+ obs[env.players_ids[0]] != obs_step_odd[env.players_ids[1]]
+ )
assert np.any(
- obs[env.players_ids[1]] != obs_step_odd[env.players_ids[
- 0]])
+ obs[env.players_ids[1]] != obs_step_odd[env.players_ids[0]]
+ )
assert_obs_is_not_symmetrical(obs, env)
if step_i == max_steps:
@@ -762,23 +1045,66 @@ def test_observations_are_not_invariant_to_the_player_trained_wt_step():
def test_observations_are_not_invariant_to_the_player_trained_wt_reset():
- p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0],
- [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]]
- p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1],
- [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]]
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- c_red_pos = [[1, 1], None, [0, 1], None, None,
- [2, 2], [0, 0], None, None, [2, 1]]
- c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2],
- None, None, [0, 0], [2, 1], None]
+ c_red_pos = [
+ [1, 1],
+ None,
+ [0, 1],
+ None,
+ None,
+ [2, 2],
+ [0, 0],
+ None,
+ None,
+ [2, 1],
+ ]
+ c_blue_pos = [
+ None,
+ [1, 1],
+ None,
+ [0, 1],
+ [2, 2],
+ None,
+ None,
+ [0, 0],
+ [2, 1],
+ None,
+ ]
max_steps, batch_size, grid_size = 10, 52, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size,
- same_obs_for_each_player=True)
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, same_obs_for_each_player=True
+ )
- batch_deltas = [i % max_steps if i % 2 == 0 else i % max_steps - 1
- for i in range(batch_size)]
+ batch_deltas = [
+ i % max_steps if i % 2 == 0 else i % max_steps - 1
+ for i in range(batch_size)
+ ]
for env_i, env in enumerate(envs):
obs = env.reset()
@@ -786,12 +1112,26 @@ def test_observations_are_not_invariant_to_the_player_trained_wt_reset():
step_i = 0
for _ in range(n_steps):
- overwrite_pos(step_i, batch_deltas, max_steps, env, p_red_pos,
- p_blue_pos, c_red_pos, c_blue_pos)
- actions = {"player_red": [p_red_act[(step_i + delta) % max_steps]
- for delta in batch_deltas],
- "player_blue": [p_blue_act[(step_i + delta) % max_steps]
- for delta in batch_deltas]}
+ overwrite_pos(
+ step_i,
+ batch_deltas,
+ max_steps,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ )
+ actions = {
+ "player_red": [
+ p_red_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ "player_blue": [
+ p_blue_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ }
_, _, _, _ = env.step(actions)
step_i += 1
@@ -812,27 +1152,50 @@ def test_who_pick_is_random():
max_steps, batch_size, grid_size = int(4 * size), 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=1.0, blue_speed=1.0, red_own=1.0, blue_own=0.0)
-
- envs = init_several_env(max_steps, batch_size, grid_size,
- players_can_pick_same_coin=False)
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act,
- env, grid_size, max_steps,
- p_red_pos, p_blue_pos, c_red_pos, c_blue_pos,
- red_speed=0.5, blue_speed=0.5, red_own=1.0, blue_own=0.0)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=1.0,
+ blue_own=0.0,
+ )
+
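+    # With players_can_pick_same_coin=False the picker is chosen at random, so each
+    # player is expected to pick roughly half the time; delta_err widens the tolerance.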
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, players_can_pick_same_coin=False
+ )
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.5,
+ blue_speed=0.5,
+ red_own=1.0,
+ blue_own=0.0,
+ repetitions=1,
+ delta_err=0.05,
+ )
diff --git a/tests/marltoolbox/envs/test_mixed_motive_vectorized_coin_game.py b/tests/marltoolbox/envs/test_vectorized_mixed_motive_coin_game.py
similarity index 51%
rename from tests/marltoolbox/envs/test_mixed_motive_vectorized_coin_game.py
rename to tests/marltoolbox/envs/test_vectorized_mixed_motive_coin_game.py
index 6dc6107..cfc7b42 100644
--- a/tests/marltoolbox/envs/test_mixed_motive_vectorized_coin_game.py
+++ b/tests/marltoolbox/envs/test_vectorized_mixed_motive_coin_game.py
@@ -10,160 +10,88 @@
assert_obs_is_symmetrical,
assert_obs_is_not_symmetrical,
)
+from coin_game_tests_utils import (
+ check_custom_obs,
+ assert_logger_buffer_size,
+ helper_test_reset,
+ helper_test_step,
+ init_several_envs,
+ helper_test_multiple_steps,
+ helper_test_multi_ple_episodes,
+ helper_assert_info,
+)
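+# The shared coin_game_tests_utils helpers replace the per-file reset/step/episode boilerplate that was previously duplicated here.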
# TODO add tests for grid_size != 3
-def test_reset():
- max_steps, batch_size, grid_size = 20, 5, 3
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
-
-def init_several_env(
+def init_my_envs(
max_steps,
batch_size,
grid_size,
players_can_pick_same_coin=True,
same_obs_for_each_player=True,
):
- mixed_motive_coin_game = init_env(
- max_steps,
- batch_size,
- VectMixedMotiveCG,
- grid_size,
+ return init_several_envs(
+ classes=(VectMixedMotiveCG,),
+ max_steps=max_steps,
+ grid_size=grid_size,
+ batch_size=batch_size,
players_can_pick_same_coin=players_can_pick_same_coin,
same_obs_for_each_player=same_obs_for_each_player,
)
- return [mixed_motive_coin_game]
-def init_env(
- max_steps,
- batch_size,
- env_class,
- seed=None,
- grid_size=3,
- players_can_pick_same_coin=True,
- same_obs_for_each_player=False,
-):
- config = {
- "max_steps": max_steps,
- "batch_size": batch_size,
- "grid_size": grid_size,
- "same_obs_for_each_player": same_obs_for_each_player,
- "both_players_can_pick_the_same_coin": players_can_pick_same_coin,
- }
- env = env_class(config)
- env.seed(seed)
- return env
+def check_obs(obs, batch_size, grid_size):
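+    # The mixed-motive grid holds two coins at once (one red, one blue),
+    # so the coin channels are expected to sum to 2.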
+ check_custom_obs(
+ obs, grid_size, n_in_2_and_above=2.0, batch_size=batch_size
+ )
-def check_obs(obs, batch_size, grid_size):
- assert len(obs) == 2, "two players"
- for i in range(batch_size):
- for key, player_obs in obs.items():
- assert player_obs.shape == (batch_size, grid_size, grid_size, 4)
- assert (
- player_obs[i, ..., 0].sum() == 1.0
- ), f"observe 1 player red in grid: {player_obs[i, ..., 0]}"
- assert (
- player_obs[i, ..., 1].sum() == 1.0
- ), f"observe 1 player blue in grid: {player_obs[i, ..., 1]}"
- assert (
- player_obs[i, ..., 2:].sum() == 2.0
- ), f"observe 1 coin in grid: {player_obs[i, ..., 0]}"
-
-
-def assert_logger_buffer_size(env, n_steps):
- assert len(env.red_pick) == n_steps
- assert len(env.red_pick_own) == n_steps
- assert len(env.blue_pick) == n_steps
- assert len(env.blue_pick_own) == n_steps
+def test_reset():
+ max_steps, batch_size, grid_size = 20, 5, 3
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_reset(
+ envs, check_obs, grid_size=grid_size, batch_size=batch_size
+ )
def test_step():
max_steps, batch_size, grid_size = 20, 5, 3
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- actions = {
- policy_id: [
- random.randint(0, env.NUM_ACTIONS - 1)
- for _ in range(batch_size)
- ]
- for policy_id in env.players_ids
- }
- obs, reward, done, info = env.step(actions)
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=1)
- assert not done["__all__"]
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_step(
+ envs,
+ check_obs,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ )
def test_multiple_steps():
max_steps, batch_size, grid_size = 20, 5, 3
n_steps = int(max_steps * 0.75)
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- for step_i in range(1, n_steps, 1):
- actions = {
- policy_id: [
- random.randint(0, env.NUM_ACTIONS - 1)
- for _ in range(batch_size)
- ]
- for policy_id in env.players_ids
- }
- obs, reward, done, info = env.step(actions)
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"]
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_multiple_steps(
+ envs,
+ n_steps,
+ check_obs,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ )
def test_multiple_episodes():
max_steps, batch_size, grid_size = 20, 100, 3
n_steps = int(max_steps * 8.25)
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- step_i = 0
- for _ in range(n_steps):
- step_i += 1
- actions = {
- policy_id: [
- random.randint(0, env.NUM_ACTIONS - 1)
- for _ in range(batch_size)
- ]
- for policy_id in env.players_ids
- }
- obs, reward, done, info = env.step(actions)
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"] or (
- step_i == max_steps and done["__all__"]
- )
- if done["__all__"]:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_multi_ple_episodes(
+ envs,
+ n_steps,
+ max_steps,
+ check_obs,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ )
def overwrite_pos(
@@ -175,6 +103,7 @@ def overwrite_pos(
p_blue_pos,
c_red_pos,
c_blue_pos,
+ **kwargs,
):
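+    # As in the other coin-game tests, **kwargs lets the shared helpers pass
+    # extra options that this function ignores.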
assert len(p_red_pos) == n_steps_in_epi
assert len(p_blue_pos) == n_steps_in_epi
@@ -200,75 +129,6 @@ def overwrite_pos(
env.blue_coin_pos = np.array(env.blue_coin_pos)
-def assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- n_steps_in_epi,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed,
- blue_speed,
- red_own,
- blue_own,
-):
- step_i = 0
-
- for _ in range(n_steps):
- overwrite_pos(
- step_i,
- batch_deltas,
- n_steps_in_epi,
- env,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- )
- actions = {
- "player_red": [
- p_red_act[(step_i + delta) % n_steps_in_epi]
- for delta in batch_deltas
- ],
- "player_blue": [
- p_blue_act[(step_i + delta) % n_steps_in_epi]
- for delta in batch_deltas
- ],
- }
- step_i += 1
-
- obs, reward, done, info = env.step(actions)
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=step_i)
- assert not done["__all__"] or (
- step_i == n_steps_in_epi and done["__all__"]
- )
-
- if done["__all__"]:
- assert info["player_red"]["pick_speed"] == red_speed
- assert info["player_blue"]["pick_speed"] == blue_speed
-
- if red_own is None:
- assert "pick_own_color" not in info["player_red"]
- else:
- assert info["player_red"]["pick_own_color"] == red_own
- if blue_own is None:
- assert "pick_own_color" not in info["player_blue"]
- else:
- assert info["player_blue"]["pick_own_color"] == blue_own
-
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
- step_i = 0
-
-
def test_logged_info_no_picking():
p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
@@ -278,33 +138,27 @@ def test_logged_info_no_picking():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env in envs:
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__red_pick_red_all_the_time():
@@ -316,33 +170,27 @@ def test_logged_info__red_pick_red_all_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__blue_pick_red_all_the_time():
@@ -354,33 +202,27 @@ def test_logged_info__blue_pick_red_all_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__blue_pick_blue_all_the_time():
@@ -392,33 +234,27 @@ def test_logged_info__blue_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__red_pick_blue_all_the_time():
@@ -430,33 +266,27 @@ def test_logged_info__red_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__red_pick_blue_all_the_time_wt_difference_in_actions():
@@ -468,33 +298,27 @@ def test_logged_info__red_pick_blue_all_the_time_wt_difference_in_actions():
c_blue_pos = [[1, 1], [1, 2], [2, 0], [0, 0]]
max_steps, batch_size, grid_size = 4, 4, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__both_pick_blue_all_the_time():
@@ -506,33 +330,27 @@ def test_logged_info__both_pick_blue_all_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=1.0,
- blue_speed=1.0,
- red_own=0.0,
- blue_own=1.0,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=0.0,
+ blue_own=1.0,
+ )
def test_logged_info__both_pick_red_all_the_time():
@@ -544,33 +362,27 @@ def test_logged_info__both_pick_red_all_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=1.0,
- blue_speed=1.0,
- red_own=1.0,
- blue_own=0.0,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=1.0,
+ blue_own=0.0,
+ )
def test_logged_info__both_pick_red_half_the_time():
@@ -582,33 +394,27 @@ def test_logged_info__both_pick_red_half_the_time():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__both_pick_blue_half_the_time():
@@ -620,33 +426,27 @@ def test_logged_info__both_pick_blue_half_the_time():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__both_pick_blue():
@@ -658,33 +458,27 @@ def test_logged_info__both_pick_blue():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__pick_half_the_time_half_blue_half_red():
@@ -696,33 +490,27 @@ def test_logged_info__pick_half_the_time_half_blue_half_red():
c_blue_pos = [[0, 0], [1, 1], [0, 0], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.0,
- blue_speed=0.0,
- red_own=None,
- blue_own=None,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
def test_logged_info__pick_slowly_red_coin():
@@ -734,33 +522,27 @@ def test_logged_info__pick_slowly_red_coin():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.25,
- blue_speed=0.25,
- red_own=1.0,
- blue_own=0.0,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.25,
+ blue_speed=0.25,
+ red_own=1.0,
+ blue_own=0.0,
+ )
def test_logged_info__pick_slowly_blue_coin():
@@ -772,33 +554,27 @@ def test_logged_info__pick_slowly_blue_coin():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.25,
- blue_speed=0.25,
- red_own=0.0,
- blue_own=1.0,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.25,
+ blue_speed=0.25,
+ red_own=0.0,
+ blue_own=1.0,
+ )
def test_logged_info__pick_quickly_red_coin():
@@ -810,33 +586,27 @@ def test_logged_info__pick_quickly_red_coin():
c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.75,
- blue_speed=0.75,
- red_own=1.0,
- blue_own=0.0,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.75,
+ blue_speed=0.75,
+ red_own=1.0,
+ blue_own=0.0,
+ )
def test_logged_info__pick_quickly_blue_coin():
@@ -848,33 +618,27 @@ def test_logged_info__pick_quickly_blue_coin():
c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.75,
- blue_speed=0.75,
- red_own=0.0,
- blue_own=1.0,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.75,
+ blue_speed=0.75,
+ red_own=0.0,
+ blue_own=1.0,
+ )
def test_logged_info__pick_slowly_mixed_coin():
@@ -886,33 +650,27 @@ def test_logged_info__pick_slowly_mixed_coin():
c_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=0.50,
- blue_speed=0.50,
- red_own=0.5,
- blue_own=0.5,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.50,
+ blue_speed=0.50,
+ red_own=0.5,
+ blue_own=0.5,
+ )
def test_logged_info__pick_quickly_mixed_coin():
@@ -924,39 +682,33 @@ def test_logged_info__pick_quickly_mixed_coin():
c_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1]]
max_steps, batch_size, grid_size = 4, 28, 3
n_steps = max_steps
- envs = init_several_env(max_steps, batch_size, grid_size)
-
- batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size)
-
- for env_i, env in enumerate(envs):
- obs = env.reset()
- check_obs(obs, batch_size, grid_size)
- assert_logger_buffer_size(env, n_steps=0)
-
- assert_info(
- batch_deltas,
- n_steps,
- batch_size,
- p_red_act,
- p_blue_act,
- env,
- grid_size,
- max_steps,
- p_red_pos,
- p_blue_pos,
- c_red_pos,
- c_blue_pos,
- red_speed=1.0,
- blue_speed=1.0,
- red_own=0.5,
- blue_own=0.5,
- )
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=0.5,
+ blue_own=0.5,
+ )
def test_get_and_set_env_state():
max_steps, batch_size, grid_size = 20, 100, 3
n_steps = int(max_steps * 8.25)
- envs = init_several_env(max_steps, batch_size, grid_size)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
for env in envs:
obs = env.reset()
@@ -1071,7 +823,7 @@ def test_observations_are_invariant_to_the_player_trained_wt_step():
]
max_steps, batch_size, grid_size = 10, 52, 3
n_steps = max_steps
- envs = init_several_env(
+ envs = init_my_envs(
max_steps, batch_size, grid_size, same_obs_for_each_player=False
)
@@ -1177,7 +929,7 @@ def test_observations_are_invariant_to_the_player_trained_wt_reset():
]
max_steps, batch_size, grid_size = 10, 52, 3
n_steps = max_steps
- envs = init_several_env(
+ envs = init_my_envs(
max_steps, batch_size, grid_size, same_obs_for_each_player=False
)
@@ -1273,7 +1025,7 @@ def test_observations_are_not_invariant_to_the_player_trained_wt_step():
]
max_steps, batch_size, grid_size = 10, 52, 3
n_steps = max_steps
- envs = init_several_env(
+ envs = init_my_envs(
max_steps, batch_size, grid_size, same_obs_for_each_player=True
)
@@ -1381,7 +1133,7 @@ def test_observations_are_not_invariant_to_the_player_trained_wt_reset():
]
max_steps, batch_size, grid_size = 10, 52, 3
n_steps = max_steps
- envs = init_several_env(
+ envs = init_my_envs(
max_steps, batch_size, grid_size, same_obs_for_each_player=True
)
diff --git a/tests/marltoolbox/envs/test_vectorized_ssd_mm_coin_game.py b/tests/marltoolbox/envs/test_vectorized_ssd_mm_coin_game.py
new file mode 100644
index 0000000..620dd4e
--- /dev/null
+++ b/tests/marltoolbox/envs/test_vectorized_ssd_mm_coin_game.py
@@ -0,0 +1,862 @@
+import copy
+import random
+
+import numpy as np
+from flaky import flaky
+
+from coin_game_tests_utils import (
+ check_custom_obs,
+ helper_test_reset,
+ helper_test_step,
+ init_several_envs,
+ helper_test_multiple_steps,
+ helper_test_multi_ple_episodes,
+ helper_assert_info,
+ shift_consistently,
+)
+from marltoolbox.envs.vectorized_ssd_mm_coin_game import (
+ VectSSDMixedMotiveCG,
+)
+from test_coin_game import (
+ assert_obs_is_symmetrical,
+ assert_obs_is_not_symmetrical,
+)
+
+
+# TODO add tests for grid_size != 3
+
+
+def init_my_envs(
+ max_steps,
+ batch_size,
+ grid_size,
+ players_can_pick_same_coin=True,
+ same_obs_for_each_player=True,
+):
+ return init_several_envs(
+ classes=(VectSSDMixedMotiveCG,),
+ max_steps=max_steps,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ players_can_pick_same_coin=players_can_pick_same_coin,
+ same_obs_for_each_player=same_obs_for_each_player,
+ )
+
+
+def check_obs(obs, batch_size, grid_size):
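+    # The SSD mixed-motive variant is assumed to expose 6 observation
+    # layers (the base coin game uses 4), hence n_layers=6 below.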
+ check_custom_obs(
+ obs, grid_size, batch_size=batch_size, n_layers=6, n_in_2_and_above=2.0
+ )
+
+
+def test_reset():
+ max_steps, batch_size, grid_size = 20, 5, 3
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_reset(
+ envs, check_obs, grid_size=grid_size, batch_size=batch_size
+ )
+
+
+def test_step():
+ max_steps, batch_size, grid_size = 20, 5, 3
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_step(
+ envs,
+ check_obs,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ )
+
+
+def test_multiple_steps():
+ max_steps, batch_size, grid_size = 20, 5, 3
+ n_steps = int(max_steps * 0.75)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_multiple_steps(
+ envs,
+ n_steps,
+ check_obs,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ )
+
+
+def test_multiple_episodes():
+ max_steps, batch_size, grid_size = 20, 100, 3
+ n_steps = int(max_steps * 8.25)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+ helper_test_multi_ple_episodes(
+ envs,
+ n_steps,
+ max_steps,
+ check_obs,
+ grid_size=grid_size,
+ batch_size=batch_size,
+ )
+
+
+def overwrite_pos(
+ step_i,
+ batch_deltas,
+ n_steps_in_epi,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ c_red_coin,
+ **kwargs,
+):
+ assert len(c_red_coin) == n_steps_in_epi
+ assert len(p_red_pos) == n_steps_in_epi
+ assert len(p_blue_pos) == n_steps_in_epi
+ assert len(c_red_pos) == n_steps_in_epi
+ assert len(c_blue_pos) == n_steps_in_epi
+
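+    # shift_consistently is assumed to roll each scripted sequence by a
+    # per-batch-element delta, so every element of the batch follows the
+    # same scenario but offset in time.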
+ env.red_coin = shift_consistently(
+ c_red_coin, step_i, n_steps_in_epi, batch_deltas
+ )
+ env.red_pos = shift_consistently(
+ p_red_pos, step_i, n_steps_in_epi, batch_deltas
+ )
+ env.blue_pos = shift_consistently(
+ p_blue_pos, step_i, n_steps_in_epi, batch_deltas
+ )
+ env.red_coin_pos = shift_consistently(
+ c_red_pos, step_i, n_steps_in_epi, batch_deltas
+ )
+ env.blue_coin_pos = shift_consistently(
+ c_blue_pos, step_i, n_steps_in_epi, batch_deltas
+ )
+
+ env.red_coin = np.array(env.red_coin, dtype=np.int8)
+ env.red_pos = np.array(env.red_pos)
+ env.blue_pos = np.array(env.blue_pos)
+ env.red_coin_pos = np.array(env.red_coin_pos)
+ env.blue_coin_pos = np.array(env.blue_coin_pos)
+
+
+def test_logged_info_no_picking():
+ p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
+ p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
+ c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_red_coin = [1, 1, 1, 1]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
+
+
+def test_logged_info__red_pick_red_all_the_time():
+ p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
+ c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_red_coin = [0, 0, 0, 0]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=0.0,
+ red_own=1.0,
+ blue_own=None,
+ )
+
+
+def test_logged_info__blue_pick_red_all_the_time():
+ p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
+ c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_red_coin = [1, 1, 1, 1]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=1.0,
+ blue_own=0.0,
+ )
+
+
+def test_logged_info__blue_pick_blue_all_the_time():
+ p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
+ p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
+ c_red_coin = [1, 1, 1, 1]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=1.0,
+ red_own=None,
+ blue_own=1.0,
+ )
+
+
+def test_logged_info__red_cant_pick_selfish_blue():
+ p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
+ c_red_coin = [1, 1, 1, 1]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
+
+
+def test_logged_info__both_pick_coop_blue_all_the_time_wt_difference_in_actions():
+ p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_red_act = [0, 1, 2, 3]
+ p_blue_act = [0, 1, 2, 3]
+ c_red_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_blue_pos = [[1, 1], [1, 2], [2, 0], [0, 0]]
+ c_red_coin = [0, 0, 0, 0]
+ max_steps, batch_size, grid_size = 4, 4, 3
+ n_steps = max_steps
+
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=0.0,
+ blue_own=1.0,
+ )
+
+
+def test_logged_info__both_pick_blue_all_the_time():
+ p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
+ c_red_coin = [0, 0, 0, 0]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=0.0,
+ blue_own=1.0,
+ )
+
+
+def test_logged_info__both_pick_red_all_the_time():
+ p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
+ c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_red_coin = [1, 1, 1, 1]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=1.0,
+ blue_own=0.0,
+ )
+
+
+def test_logged_info__both_pick_coop_red_half_the_time():
+ p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]]
+ p_blue_pos = [[0, 0], [0, 0], [1, 0], [1, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
+ c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_red_coin = [1, 1, 1, 1]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.5,
+ blue_speed=0.5,
+ red_own=1.0,
+ blue_own=0.0,
+ )
+
+
+def test_logged_info__both_pick_selfish_blue_half_the_time():
+ p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]]
+ p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
+ c_red_coin = [1, 1, 1, 1]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.5,
+ red_own=None,
+ blue_own=1.0,
+ )
+
+
+def test_logged_info__both_dont_pick_coop_red():
+ p_red_pos = [[0, 0], [0, 0], [0, 0], [1, 0]]
+ p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]]
+ c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]]
+ c_red_coin = [1, 1, 1, 1]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.0,
+ blue_speed=0.0,
+ red_own=None,
+ blue_own=None,
+ )
+
+
+def test_logged_info__pick_half_the_time_half_selfish_blue_half_selfish_red():
+ p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]]
+ p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]]
+ p_red_act = [0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0]
+ c_red_pos = [[1, 1], [2, 2], [1, 1], [2, 2]]
+ c_blue_pos = [[2, 2], [1, 1], [2, 2], [1, 1]]
+ c_red_coin = [1, 1, 0, 0]
+ max_steps, batch_size, grid_size = 4, 28, 3
+ n_steps = max_steps
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=0.25,
+ blue_speed=0.25,
+ red_own=1.0,
+ blue_own=1.0,
+ )
+
+
+def test_get_and_set_env_state():
+ max_steps, batch_size, grid_size = 20, 100, 3
+ n_steps = int(max_steps * 8.25)
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
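+    # _save_env/_load_env are assumed to snapshot and restore the full env
+    # state; the loop below checks that stepping never mutates a saved
+    # snapshot and that loading one restores every attribute (the np_random
+    # generator is excluded from the comparison).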
+ for env in envs:
+ obs = env.reset()
+ initial_env_state = env._save_env()
+ initial_env_state_saved = copy.deepcopy(initial_env_state)
+ env_initial = copy.deepcopy(env)
+
+ step_i = 0
+ for _ in range(n_steps):
+ step_i += 1
+ actions = {
+ policy_id: [
+ random.randint(0, env.NUM_ACTIONS - 1)
+ for _ in range(batch_size)
+ ]
+ for policy_id in env.players_ids
+ }
+ obs, reward, done, info = env.step(actions)
+
+ assert all(
+ [
+ v == initial_env_state_saved[k]
+ if not isinstance(v, np.ndarray)
+ else (v == initial_env_state_saved[k]).all()
+ for k, v in initial_env_state.items()
+ ]
+ )
+ env_state_after_step = env._save_env()
+ env_after_step = copy.deepcopy(env)
+
+ env._load_env(initial_env_state)
+ env_vars, env_initial_vars = vars(env), vars(env_initial)
+ env_vars.pop("np_random", None)
+ env_initial_vars.pop("np_random", None)
+ assert all(
+ [
+ v == env_initial_vars[k]
+ if not isinstance(v, np.ndarray)
+ else (v == env_initial_vars[k]).all()
+ for k, v in env_vars.items()
+ ]
+ )
+
+ env._load_env(env_state_after_step)
+ env_vars, env_after_step_vars = vars(env), vars(env_after_step)
+ env_vars.pop("np_random", None)
+ env_after_step_vars.pop("np_random", None)
+ assert all(
+ [
+ v == env_after_step_vars[k]
+ if not isinstance(v, np.ndarray)
+ else (v == env_after_step_vars[k]).all()
+ for k, v in env_vars.items()
+ ]
+ )
+
+ if done["__all__"]:
+ obs = env.reset()
+ step_i = 0
+
+
+def test_observations_are_not_invariant_to_the_player_trained_wt_step():
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
+ p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ c_red_pos = [
+ [1, 1],
+ [2, 2],
+ [0, 1],
+ [2, 2],
+ [2, 2],
+ [2, 2],
+ [0, 0],
+ [2, 2],
+ [2, 2],
+ [2, 1],
+ ]
+ c_blue_pos = [
+ [2, 2],
+ [1, 1],
+ [2, 2],
+ [0, 1],
+ [2, 2],
+ [2, 2],
+ [2, 2],
+ [0, 0],
+ [2, 1],
+ [2, 2],
+ ]
+ c_red_coin = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ max_steps, batch_size, grid_size = 10, 52, 3
+ n_steps = max_steps
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, same_obs_for_each_player=True
+ )
+
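+    # batch_deltas is assumed to desynchronize the batch: each batch
+    # element replays the scripted trajectory from a different offset.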
+ batch_deltas = [
+ i % max_steps if i % 2 == 0 else i % max_steps - 1
+ for i in range(batch_size)
+ ]
+
+ for env_i, env in enumerate(envs):
+ _ = env.reset()
+ step_i = 0
+
+ for _ in range(n_steps):
+ overwrite_pos(
+ step_i,
+ batch_deltas,
+ max_steps,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ c_red_coin,
+ )
+ actions = {
+ "player_red": [
+ p_red_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ "player_blue": [
+ p_blue_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ }
+ obs, reward, done, info = env.step(actions)
+
+ step_i += 1
+            # assert that the observations are not symmetrical
+            # with respect to the actions
+ if step_i % 2 == 1:
+ obs_step_odd = obs
+ elif step_i % 2 == 0:
+ assert np.any(
+ obs[env.players_ids[0]] != obs_step_odd[env.players_ids[1]]
+ )
+ assert np.any(
+ obs[env.players_ids[1]] != obs_step_odd[env.players_ids[0]]
+ )
+ assert_obs_is_not_symmetrical(obs, env)
+
+ if step_i == max_steps:
+ break
+
+
+def test_observations_are_not_invariant_to_the_player_trained_wt_reset():
+ p_red_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [1, 1],
+ [2, 0],
+ [0, 1],
+ [2, 2],
+ [1, 2],
+ ]
+ p_blue_pos = [
+ [0, 0],
+ [0, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1],
+ [0, 0],
+ [0, 1],
+ [2, 0],
+ [1, 2],
+ [2, 2],
+ ]
+ p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ c_red_pos = [
+ [1, 1],
+ [2, 2],
+ [0, 1],
+ [2, 2],
+ [2, 2],
+ [2, 2],
+ [0, 0],
+ [2, 2],
+ [2, 2],
+ [2, 1],
+ ]
+ c_blue_pos = [
+ [2, 2],
+ [1, 1],
+ [2, 2],
+ [0, 1],
+ [2, 2],
+ [2, 2],
+ [2, 2],
+ [0, 0],
+ [2, 1],
+ [2, 2],
+ ]
+ c_red_coin = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ max_steps, batch_size, grid_size = 10, 52, 3
+ n_steps = max_steps
+ envs = init_my_envs(
+ max_steps, batch_size, grid_size, same_obs_for_each_player=True
+ )
+
+ batch_deltas = [
+ i % max_steps if i % 2 == 0 else i % max_steps - 1
+ for i in range(batch_size)
+ ]
+
+ for env_i, env in enumerate(envs):
+ obs = env.reset()
+ assert_obs_is_not_symmetrical(obs, env)
+ step_i = 0
+
+ for _ in range(n_steps):
+ overwrite_pos(
+ step_i,
+ batch_deltas,
+ max_steps,
+ env,
+ p_red_pos,
+ p_blue_pos,
+ c_red_pos,
+ c_blue_pos,
+ c_red_coin,
+ )
+ actions = {
+ "player_red": [
+ p_red_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ "player_blue": [
+ p_blue_act[(step_i + delta) % max_steps]
+ for delta in batch_deltas
+ ],
+ }
+ _, _, _, _ = env.step(actions)
+
+ step_i += 1
+
+ if step_i == max_steps:
+ break
+
+
+@flaky(max_runs=4, min_passes=1)
+def test_who_pick_is_random():
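+    # Which player ends up picking the coin is assumed to be random, so a
+    # single run may deviate from the expected averages; @flaky retries the
+    # test up to 4 times to absorb that randomness.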
+ size = 100
+ p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] * size
+ p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] * size
+ p_red_act = [0, 0, 0, 0] * size
+ p_blue_act = [0, 0, 0, 0] * size
+ c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] * size
+ c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] * size
+ c_red_coin = [1, 1, 1, 1] * size
+ max_steps, batch_size, grid_size = int(4 * size), 28, 3
+ n_steps = max_steps
+
+ envs = init_my_envs(max_steps, batch_size, grid_size)
+
+ helper_assert_info(
+ n_steps=n_steps,
+ batch_size=batch_size,
+ p_red_act=p_red_act,
+ p_blue_act=p_blue_act,
+ envs=envs,
+ grid_size=grid_size,
+ max_steps=max_steps,
+ p_red_pos=p_red_pos,
+ p_blue_pos=p_blue_pos,
+ c_red_pos=c_red_pos,
+ c_blue_pos=c_blue_pos,
+ c_red_coin=c_red_coin,
+ check_obs_fn=check_obs,
+ overwrite_pos_fn=overwrite_pos,
+ red_speed=1.0,
+ blue_speed=1.0,
+ red_own=1.0,
+ blue_own=0.0,
+ repetitions=1,
+ )
diff --git a/tests/marltoolbox/examples_and_experiments/manual_test_end_to_end.py b/tests/marltoolbox/examples_and_experiments/manual_test_end_to_end.py
index 6668b41..90b3539 100644
--- a/tests/marltoolbox/examples_and_experiments/manual_test_end_to_end.py
+++ b/tests/marltoolbox/examples_and_experiments/manual_test_end_to_end.py
@@ -4,256 +4,316 @@
import ray
from marltoolbox.utils import postprocessing
-from marltoolbox.utils.miscellaneous import check_learning_achieved
+from marltoolbox.utils.tune_analysis import check_learning_achieved
def print_metrics_available(tune_analysis):
- print("metric available in tune_analysis:",
- tune_analysis.results_df.columns.tolist())
+ print(
+ "metric available in tune_analysis:",
+ tune_analysis.results_df.columns.tolist(),
+ )
def test_pg_ipd():
from marltoolbox.examples.rllib_api.pg_ipd import main
+
# Restart Ray defensively in case the ray connection is lost.
ray.shutdown()
tune_analysis = main(debug=False)
print_metrics_available(tune_analysis)
+ check_learning_achieved(tune_results=tune_analysis, max_=-75)
check_learning_achieved(
tune_results=tune_analysis,
- max_=-75)
+ min_=0.9,
+ metric="custom_metrics.DD_freq/player_row_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.9,
- metric="custom_metrics.DD_freq/player_row_mean")
+ metric="custom_metrics.DD_freq/player_col_mean",
+ )
+
+
+def test_r2d2_ipd():
+    from marltoolbox.examples.rllib_api.r2d2_ipd import main
+
+ # Restart Ray defensively in case the ray connection is lost.
+ ray.shutdown()
+ tune_analysis = main(debug=False)
+ print_metrics_available(tune_analysis)
+ check_learning_achieved(tune_results=tune_analysis, max_=-75)
check_learning_achieved(
tune_results=tune_analysis,
min_=0.9,
- metric="custom_metrics.DD_freq/player_col_mean")
+ metric="custom_metrics.DD_freq/player_row_mean",
+ )
+ check_learning_achieved(
+ tune_results=tune_analysis,
+ min_=0.9,
+ metric="custom_metrics.DD_freq/player_col_mean",
+ )
def test_ltft_ipd():
from marltoolbox.experiments.rllib_api.ltft_various_env import main
+
ray.shutdown()
tune_analysis_self_play, tune_analysis_against_opponent = main(
debug=False,
env="IteratedPrisonersDilemma",
train_n_replicates=1,
- against_naive_opp=True)
+ against_naive_opp=True,
+ )
print_metrics_available(tune_analysis_self_play)
- check_learning_achieved(
- tune_results=tune_analysis_self_play,
- min_=-42)
+ check_learning_achieved(tune_results=tune_analysis_self_play, min_=-42)
check_learning_achieved(
tune_results=tune_analysis_self_play,
min_=0.9,
- metric="custom_metrics.CC_freq/player_row_mean")
+ metric="custom_metrics.CC_freq/player_row_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis_self_play,
min_=0.9,
- metric="custom_metrics.CC_freq/player_col_mean")
+ metric="custom_metrics.CC_freq/player_col_mean",
+ )
print_metrics_available(tune_analysis_against_opponent)
check_learning_achieved(
- tune_results=tune_analysis_against_opponent,
- max_=-75)
+ tune_results=tune_analysis_against_opponent, max_=-75
+ )
check_learning_achieved(
tune_results=tune_analysis_against_opponent,
min_=0.9,
- metric="custom_metrics.DD_freq/player_row_mean")
+ metric="custom_metrics.DD_freq/player_row_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis_against_opponent,
min_=0.9,
- metric="custom_metrics.DD_freq/player_col_mean")
+ metric="custom_metrics.DD_freq/player_col_mean",
+ )
def test_amtft_ipd():
from marltoolbox.experiments.rllib_api.amtft_various_env import main
+
ray.shutdown()
tune_analysis_per_welfare, analysis_metrics_per_mode = main(
- debug=False, train_n_replicates=1, filter_utilitarian=False,
- env="IteratedPrisonersDilemma")
+ debug=False,
+ train_n_replicates=1,
+ filter_utilitarian=False,
+ env="IteratedPrisonersDilemma",
+ )
for welfare_name, tune_analysis in tune_analysis_per_welfare.items():
print("welfare_name", welfare_name)
print_metrics_available(tune_analysis)
- check_learning_achieved(
- tune_results=tune_analysis, min_=-204)
+ check_learning_achieved(tune_results=tune_analysis, min_=-204)
check_learning_achieved(
tune_results=tune_analysis,
min_=0.9,
- metric="custom_metrics.CC_freq/player_row_mean"
+ metric="custom_metrics.CC_freq/player_row_mean",
)
check_learning_achieved(
tune_results=tune_analysis,
min_=0.9,
- metric="custom_metrics.CC_freq/player_col_mean"
+ metric="custom_metrics.CC_freq/player_col_mean",
)
def test_ppo_asym_coin_game():
from marltoolbox.examples.rllib_api.ppo_coin_game import main
+
ray.shutdown()
tune_analysis = main(debug=False, stop_iters=200)
print_metrics_available(tune_analysis)
- check_learning_achieved(
- tune_results=tune_analysis, min_=15)
+ check_learning_achieved(tune_results=tune_analysis, min_=15)
check_learning_achieved(
tune_results=tune_analysis,
min_=0.30,
- metric="custom_metrics.pick_speed/player_red_mean")
+ metric="custom_metrics.pick_speed/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.30,
- metric="custom_metrics.pick_speed/player_blue_mean")
+ metric="custom_metrics.pick_speed/player_blue_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.40,
max_=0.60,
- metric="custom_metrics.pick_own_color/player_red_mean")
+ metric="custom_metrics.pick_own_color/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.40,
max_=0.60,
- metric="custom_metrics.pick_own_color/player_blue_mean")
+ metric="custom_metrics.pick_own_color/player_blue_mean",
+ )
def test_dqn_coin_game():
from marltoolbox.examples.rllib_api.dqn_coin_game import main
+
ray.shutdown()
tune_analysis = main(debug=False)
print_metrics_available(tune_analysis)
- check_learning_achieved(
- tune_results=tune_analysis, max_=20)
+ check_learning_achieved(tune_results=tune_analysis, max_=20)
check_learning_achieved(
tune_results=tune_analysis,
min_=0.5,
- metric="custom_metrics.pick_speed/player_red_mean")
+ metric="custom_metrics.pick_speed/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.5,
- metric="custom_metrics.pick_speed/player_blue_mean")
+ metric="custom_metrics.pick_speed/player_blue_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.40,
max_=0.6,
- metric="custom_metrics.pick_own_color/player_red_mean")
+ metric="custom_metrics.pick_own_color/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.40,
max_=0.6,
- metric="custom_metrics.pick_own_color/player_blue_mean")
+ metric="custom_metrics.pick_own_color/player_blue_mean",
+ )
def test_dqn_wt_utilitarian_welfare_coin_game():
from marltoolbox.examples.rllib_api.dqn_wt_welfare import main
+
ray.shutdown()
tune_analysis = main(debug=False)
print_metrics_available(tune_analysis)
- check_learning_achieved(
- tune_results=tune_analysis, min_=50)
+ check_learning_achieved(tune_results=tune_analysis, min_=50)
check_learning_achieved(
tune_results=tune_analysis,
min_=0.3,
- metric="custom_metrics.pick_speed/player_red_mean")
+ metric="custom_metrics.pick_speed/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.3,
- metric="custom_metrics.pick_speed/player_blue_mean")
+ metric="custom_metrics.pick_speed/player_blue_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.95,
- metric="custom_metrics.pick_own_color/player_red_mean")
+ metric="custom_metrics.pick_own_color/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.95,
- metric="custom_metrics.pick_own_color/player_blue_mean")
+ metric="custom_metrics.pick_own_color/player_blue_mean",
+ )
def test_dqn_wt_inequity_aversion_welfare_coin_game():
from marltoolbox.examples.rllib_api.dqn_wt_welfare import main
+
ray.shutdown()
- tune_analysis = main(debug=False,
- welfare=postprocessing.WELFARE_INEQUITY_AVERSION)
+ tune_analysis = main(
+ debug=False, welfare=postprocessing.WELFARE_INEQUITY_AVERSION
+ )
print_metrics_available(tune_analysis)
- check_learning_achieved(
- tune_results=tune_analysis, min_=50)
+ check_learning_achieved(tune_results=tune_analysis, min_=50)
check_learning_achieved(
tune_results=tune_analysis,
min_=0.25,
- metric="custom_metrics.pick_speed/player_red_mean")
+ metric="custom_metrics.pick_speed/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.25,
- metric="custom_metrics.pick_speed/player_blue_mean")
+ metric="custom_metrics.pick_speed/player_blue_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.9,
- metric="custom_metrics.pick_own_color/player_red_mean")
+ metric="custom_metrics.pick_own_color/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis,
min_=0.9,
- metric="custom_metrics.pick_own_color/player_blue_mean")
+ metric="custom_metrics.pick_own_color/player_blue_mean",
+ )
def test_ltft_coin_game():
from marltoolbox.experiments.rllib_api.ltft_various_env import main
+
ray.shutdown()
tune_analysis_self_play, tune_analysis_against_opponent = main(
- debug=False, env="CoinGame", train_n_replicates=1,
- against_naive_opp=True)
+ debug=False,
+ env="CoinGame",
+ train_n_replicates=1,
+ against_naive_opp=True,
+ )
print_metrics_available(tune_analysis_self_play)
- check_learning_achieved(
- tune_results=tune_analysis_self_play,
- min_=50)
+ check_learning_achieved(tune_results=tune_analysis_self_play, min_=50)
check_learning_achieved(
tune_results=tune_analysis_self_play,
min_=0.3,
- metric="custom_metrics.pick_speed/player_red_mean")
+ metric="custom_metrics.pick_speed/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis_self_play,
min_=0.3,
- metric="custom_metrics.pick_speed/player_blue_mean")
+ metric="custom_metrics.pick_speed/player_blue_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis_self_play,
min_=0.9,
- metric="custom_metrics.pick_own_color/player_red_mean")
+ metric="custom_metrics.pick_own_color/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis_self_play,
min_=0.9,
- metric="custom_metrics.pick_own_color/player_blue_mean")
+ metric="custom_metrics.pick_own_color/player_blue_mean",
+ )
print_metrics_available(tune_analysis_against_opponent)
check_learning_achieved(
- tune_results=tune_analysis_against_opponent,
- max_=20)
+ tune_results=tune_analysis_against_opponent, max_=20
+ )
check_learning_achieved(
tune_results=tune_analysis_against_opponent,
min_=0.3,
- metric="custom_metrics.pick_speed/player_red_mean")
+ metric="custom_metrics.pick_speed/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis_against_opponent,
min_=0.3,
- metric="custom_metrics.pick_speed/player_blue_mean")
+ metric="custom_metrics.pick_speed/player_blue_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis_against_opponent,
max_=0.6,
- metric="custom_metrics.pick_own_color/player_red_mean")
+ metric="custom_metrics.pick_own_color/player_red_mean",
+ )
check_learning_achieved(
tune_results=tune_analysis_against_opponent,
max_=0.6,
- metric="custom_metrics.pick_own_color/player_blue_mean")
+ metric="custom_metrics.pick_own_color/player_blue_mean",
+ )
def test_amtft_coin_game():
from marltoolbox.experiments.rllib_api.amtft_various_env import main
+
ray.shutdown()
tune_analysis_per_welfare, analysis_metrics_per_mode = main(
- debug=False, train_n_replicates=1, filter_utilitarian=False,
- env="CoinGame")
+ debug=False,
+ train_n_replicates=1,
+ filter_utilitarian=False,
+ env="CoinGame",
+ )
for welfare_name, tune_analysis in tune_analysis_per_welfare.items():
print("welfare_name", welfare_name)
print_metrics_available(tune_analysis)
- check_learning_achieved(
- tune_results=tune_analysis,
- min_=40)
+ check_learning_achieved(tune_results=tune_analysis, min_=40)
check_learning_achieved(
tune_results=tune_analysis,
min_=0.25,
- metric="custom_metrics.pick_speed/player_red_mean")
+ metric="custom_metrics.pick_speed/player_red_mean",
+ )
diff --git a/tests/marltoolbox/examples_and_experiments/test_end_to_end.py b/tests/marltoolbox/examples_and_experiments/test_end_to_end.py
index 0cac310..531d961 100644
--- a/tests/marltoolbox/examples_and_experiments/test_end_to_end.py
+++ b/tests/marltoolbox/examples_and_experiments/test_end_to_end.py
@@ -5,23 +5,37 @@ def test_pg_ipd():
from marltoolbox.examples.rllib_api.pg_ipd import main
ray.shutdown() # Restart Ray defensively in case the ray connection is lost.
- main(stop_iters=10, tf=False, debug=True)
+ main(debug=True)
+
+
+def test_r2d2_ipd():
+ from marltoolbox.examples.rllib_api.r2d2_ipd import main
+
+ ray.shutdown() # Restart Ray defensively in case the ray connection is lost.
+ main(debug=True)
def test_ppo_asym_coin_game():
from marltoolbox.examples.rllib_api.ppo_coin_game import main
ray.shutdown()
- main(debug=True, stop_iters=3, tf=False)
+ main(debug=True)
-def test_ppo_asym_coin_game():
+def test_dqn_coin_game():
from marltoolbox.examples.rllib_api.dqn_coin_game import main
ray.shutdown()
main(debug=True)
+def test_r2d2_coin_game():
+ from marltoolbox.examples.rllib_api.r2d2_coin_game import main
+
+ ray.shutdown() # Restart Ray defensively in case the ray connection is lost.
+ main(debug=True)
+
+
def test_ltft_ipd():
from marltoolbox.experiments.rllib_api.ltft_various_env import main
@@ -43,6 +57,13 @@ def test_amtft_ipd():
main(debug=True, env="IteratedPrisonersDilemma")
+def test_amtft_ipd_with_r2d2():
+ from marltoolbox.experiments.rllib_api.amtft_various_env import main
+
+ ray.shutdown()
+ main(debug=True, env="IteratedPrisonersDilemma", use_r2d2=True)
+
+
def test_amtft_iasymbos():
from marltoolbox.experiments.rllib_api.amtft_various_env import main
@@ -183,3 +204,17 @@ def test_adaptive_mechanism_design_tune_class_api_wt_rllib_policy():
ray.shutdown()
main(debug=True, use_rllib_policy=True)
+
+
+def test_amtft_vs_exploiter():
+ from marltoolbox.experiments.rllib_api.amtft_vs_lvl1_exploiter import main
+
+ ray.shutdown()
+ main(debug=True)
+
+
+def test_amtft_meta_game():
+ from marltoolbox.experiments.rllib_api.amtft_meta_game import main
+
+ ray.shutdown()
+ main(debug=True)
diff --git a/tests/marltoolbox/utils/test_exploration.py b/tests/marltoolbox/utils/test_exploration.py
index 7934f43..e496a33 100644
--- a/tests/marltoolbox/utils/test_exploration.py
+++ b/tests/marltoolbox/utils/test_exploration.py
@@ -7,10 +7,10 @@
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.utils.schedules import PiecewiseSchedule
-from marltoolbox.envs.coin_game import \
- CoinGame
-from marltoolbox.envs.matrix_sequential_social_dilemma import \
- IteratedPrisonersDilemma
+from marltoolbox.envs.coin_game import CoinGame
+from marltoolbox.envs.matrix_sequential_social_dilemma import (
+ IteratedPrisonersDilemma,
+)
from marltoolbox.utils import exploration
ROUNDING_ERROR = 1e-3
@@ -23,32 +23,36 @@ def assert_equal_wt_some_epsilon(v1, v2):
def test_clusterize_by_distance():
output = exploration.clusterize_by_distance(
- torch.Tensor([0.0, 0.4, 1.0, 1.4, 1.8, 3.0]), 0.5)
+ torch.Tensor([0.0, 0.4, 1.0, 1.4, 1.8, 3.0]), 0.5
+ )
assert_equal_wt_some_epsilon(
- output,
- torch.Tensor([0.2000, 0.2000, 1.4000, 1.4000, 1.4000, 3.0000]))
+ output, torch.Tensor([0.2000, 0.2000, 1.4000, 1.4000, 1.4000, 3.0000])
+ )
output = exploration.clusterize_by_distance(
- torch.Tensor([0.0, 0.5, 1.0, 1.4, 1.8, 3.0]), 0.5)
+ torch.Tensor([0.0, 0.5, 1.0, 1.4, 1.8, 3.0]), 0.5
+ )
assert_equal_wt_some_epsilon(
- output,
- torch.Tensor([0.0000, 0.5000, 1.4000, 1.4000, 1.4000, 3.0000]))
+ output, torch.Tensor([0.0000, 0.5000, 1.4000, 1.4000, 1.4000, 3.0000])
+ )
output = exploration.clusterize_by_distance(
- torch.Tensor([-10.0, -9.8, 1.0, 1.4, 1.8, 3.0]), 0.5)
+ torch.Tensor([-10.0, -9.8, 1.0, 1.4, 1.8, 3.0]), 0.5
+ )
assert_equal_wt_some_epsilon(
output,
- torch.Tensor([-9.9000, -9.9000, 1.4000, 1.4000, 1.4000, 3.0000]))
+ torch.Tensor([-9.9000, -9.9000, 1.4000, 1.4000, 1.4000, 3.0000]),
+ )
output = exploration.clusterize_by_distance(
- torch.Tensor([-1.0, -0.51, -0.1, 0.0, 0.1, 0.51, 1.0]), 0.5)
+ torch.Tensor([-1.0, -0.51, -0.1, 0.0, 0.1, 0.51, 1.0]), 0.5
+ )
assert_equal_wt_some_epsilon(
- output,
- torch.Tensor([0., 0., 0., 0., 0., 0., 0.]))
+ output, torch.Tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+ )
class TestSoftQSchedule:
-
def set_class_to_test(self):
self.class_to_test = exploration.SoftQSchedule
@@ -56,58 +60,59 @@ def test__set_temperature_wt_explore(self):
self.set_class_to_test()
self.arrange_for_simple_ipd()
- self.softqschedule._set_temperature(
- explore=True, timestep=0)
+ self.softqschedule._set_temperature(explore=True, timestep=0)
assert self.softqschedule.temperature == self.initial_temperature
self.softqschedule._set_temperature(
- explore=True, timestep=self.temperature_timesteps)
+ explore=True, timestep=self.temperature_timesteps
+ )
assert self.softqschedule.temperature == self.final_temperature
self.softqschedule._set_temperature(
- explore=True, timestep=self.temperature_timesteps // 2)
- assert abs(self.softqschedule.temperature -
- (self.initial_temperature - self.final_temperature) / 2) < \
- ROUNDING_ERROR
+ explore=True, timestep=self.temperature_timesteps // 2
+ )
+ assert (
+ abs(
+ self.softqschedule.temperature
+ - (self.initial_temperature - self.final_temperature) / 2
+ )
+ < ROUNDING_ERROR
+ )
def test__set_temperature_wtout_explore(self):
self.set_class_to_test()
self.arrange_for_simple_ipd()
- self.softqschedule._set_temperature(
- explore=False, timestep=0)
+ self.softqschedule._set_temperature(explore=False, timestep=0)
assert self.softqschedule.temperature == 1.0
self.softqschedule._set_temperature(
- explore=False, timestep=self.temperature_timesteps)
+ explore=False, timestep=self.temperature_timesteps
+ )
assert self.softqschedule.temperature == 1.0
self.softqschedule._set_temperature(
- explore=False, timestep=self.temperature_timesteps // 2)
+ explore=False, timestep=self.temperature_timesteps // 2
+ )
assert self.softqschedule.temperature == 1.0
def test__set_temperature_wt_explore_wt_multi_steps_schedule(self):
self.class_to_test = exploration.SoftQSchedule
self.arrange_for_multi_step_wt_coin_game()
- self.softqschedule._set_temperature(
- explore=True, timestep=0)
+ self.softqschedule._set_temperature(explore=True, timestep=0)
assert self.softqschedule.temperature == 2.0
- self.softqschedule._set_temperature(
- explore=True, timestep=2000)
+ self.softqschedule._set_temperature(explore=True, timestep=2000)
assert self.softqschedule.temperature == 0.1
- self.softqschedule._set_temperature(
- explore=True, timestep=3000)
+ self.softqschedule._set_temperature(explore=True, timestep=3000)
assert self.softqschedule.temperature == 0.1
- self.softqschedule._set_temperature(
- explore=True, timestep=500)
+ self.softqschedule._set_temperature(explore=True, timestep=500)
assert abs(self.softqschedule.temperature - 1.25) < ROUNDING_ERROR
- self.softqschedule._set_temperature(
- explore=True, timestep=1500)
+ self.softqschedule._set_temperature(explore=True, timestep=1500)
assert abs(self.softqschedule.temperature - 0.3) < ROUNDING_ERROR
def arrange_for_simple_ipd(self):
@@ -122,24 +127,21 @@ def arrange_for_multi_step_wt_coin_game(self):
self.final_temperature = 0.0
self.temperature_timesteps = 0.0
self.temperature_schedule = PiecewiseSchedule(
- endpoints=[
- (0, 2.0),
- (1000, 0.5),
- (2000, 0.1)],
+ endpoints=[(0, 2.0), (1000, 0.5), (2000, 0.1)],
outside_value=0.1,
- framework="torch")
+ framework="torch",
+ )
self.init_coin_game_scheduler()
def init_ipd_scheduler(self):
self.softqschedule = self.init_scheduler(
IteratedPrisonersDilemma.ACTION_SPACE,
- IteratedPrisonersDilemma.OBSERVATION_SPACE
+ IteratedPrisonersDilemma.OBSERVATION_SPACE,
)
def init_coin_game_scheduler(self):
self.softqschedule = self.init_scheduler(
- CoinGame.ACTION_SPACE,
- CoinGame({}).OBSERVATION_SPACE
+ CoinGame.ACTION_SPACE, CoinGame({}).OBSERVATION_SPACE
)
def init_scheduler(self, action_space, obs_space):
@@ -158,8 +160,8 @@ def init_scheduler(self, action_space, obs_space):
action_space=action_space,
num_outputs=action_space.n,
name="fc",
- model_config=MODEL_DEFAULTS
- )
+ model_config=MODEL_DEFAULTS,
+ ),
)
def test__apply_temperature(self):
@@ -173,22 +175,30 @@ def test__apply_temperature(self):
)
def apply_and_assert_apply_temperature(self, temperature, inputs):
- action_distribution, action_dist_class = \
- self.set_temperature_and_get_args(temperature=temperature,
- inputs=inputs)
+ (
+ action_distribution,
+ action_dist_class,
+ ) = self.set_temperature_and_get_args(
+ temperature=temperature, inputs=inputs
+ )
new_action_distribution = self.softqschedule._apply_temperature(
- copy.deepcopy(action_distribution), action_dist_class)
+ copy.deepcopy(action_distribution), action_dist_class
+ )
assert all(
abs(n_v - v / self.softqschedule.temperature) < ROUNDING_ERROR
- for v, n_v in zip(action_distribution.inputs,
- new_action_distribution.inputs))
+ for v, n_v in zip(
+ action_distribution.inputs, new_action_distribution.inputs
+ )
+ )
def set_temperature_and_get_args(self, temperature, inputs):
action_dist_class = TorchCategorical
+ inputs = torch.tensor(inputs)
action_distribution = TorchCategorical(
- inputs, self.softqschedule.model, temperature=1.0)
+ inputs, self.softqschedule.model, temperature=1.0
+ )
self.softqschedule.temperature = temperature
return action_distribution, action_dist_class
@@ -196,8 +206,7 @@ def test_get_exploration_action_wtout_explore(self):
self.helper_test_get_exploration_action_wt_explore(explore=False)
def random_inputs(self):
- return np.random.random(
- size=(1, np.random.randint(1, 50, size=1)[0]))
+ return np.random.random(size=(1, np.random.randint(1, 50, size=1)[0]))
def random_timestep(self):
return np.random.randint(0, 10000, size=1)[0]
@@ -206,23 +215,26 @@ def random_temperature(self):
return np.random.random(size=1)[0] * 10 + 1e-9
def apply_and_assert_get_exploration_action(
- self, inputs, explore, timestep):
+ self, inputs, explore, timestep
+ ):
- initial_action_distribution, _ = \
- self.set_temperature_and_get_args(temperature=1.0,
- inputs=inputs)
+ initial_action_distribution, _ = self.set_temperature_and_get_args(
+ temperature=1.0, inputs=inputs
+ )
action_distribution = copy.deepcopy(initial_action_distribution)
_ = self.softqschedule.get_exploration_action(
- action_distribution,
- timestep=timestep,
- explore=explore
+ action_distribution, timestep=timestep, explore=explore
)
temperature = self.softqschedule.temperature if explore else 1.0
- errors = [abs(n_v - v / temperature)
- for v, n_v in zip(initial_action_distribution.inputs[0],
- action_distribution.inputs[0])]
+ errors = [
+ abs(n_v - v / temperature)
+ for v, n_v in zip(
+ initial_action_distribution.inputs[0],
+ action_distribution.inputs[0],
+ )
+ ]
assert all(err < ROUNDING_ERROR for err in errors), f"errors: {errors}"
def test_get_exploration_action_wt_explore(self):
@@ -236,11 +248,11 @@ def helper_test_get_exploration_action_wt_explore(self, explore):
self.apply_and_assert_get_exploration_action(
inputs=self.random_inputs(),
explore=explore,
- timestep=self.random_timestep())
+ timestep=self.random_timestep(),
+ )
class TestSoftQScheduleWtClustering(TestSoftQSchedule):
-
def set_class_to_test(self):
self.class_to_test = exploration.SoftQScheduleWtClustering
@@ -250,15 +262,14 @@ def helper_test_get_exploration_action_wt_explore(self, explore):
for inputs in self.get_inputs_list():
self.apply_and_assert_get_exploration_action(
- inputs=inputs,
- explore=explore,
- timestep=self.random_timestep())
+ inputs=inputs, explore=explore, timestep=self.random_timestep()
+ )
def get_inputs_list(self):
return [
[[1.0, 0.0]],
[[5.0, -1.0]],
[[1.0, 1.6]],
- [[101, -2.3]],
- [[65, 98, 13, 56, 123, 156, 84]],
+ [[101.0, -2.3]],
+ [[65.0, 98.0, 13.0, 56.0, 123.0, 156.0, 84.0]],
]
diff --git a/tests/marltoolbox/utils/test_log.py b/tests/marltoolbox/utils/test_log.py
index 4b6ec9a..79149b6 100644
--- a/tests/marltoolbox/utils/test_log.py
+++ b/tests/marltoolbox/utils/test_log.py
@@ -1,30 +1,37 @@
import numpy as np
+import torch
-from marltoolbox.utils.log import _add_entropy_to_log
+from marltoolbox.utils.log.log import add_entropy_to_log
def test__add_entropy_to_log():
to_log = {}
- train_batch = {"action_dist_inputs": np.array([[0.0, 1.0]])}
- to_log = _add_entropy_to_log(train_batch, to_log)
- assert_close(to_log[f"entropy_buffer_samples_avg"], 0.00, 0.001)
- assert_close(to_log[f"entropy_buffer_samples_single"], 0.00, 0.001)
+ train_batch = {"action_dist_inputs": torch.tensor([[0.0, 1.0]])}
+ to_log = add_entropy_to_log(train_batch, to_log)
+ assert_are_close(to_log[f"entropy_buffer_samples_avg"], 0.00, 0.001)
+ assert_are_close(to_log[f"entropy_buffer_samples_single"], 0.00, 0.001)
to_log = {}
- train_batch = {"action_dist_inputs": np.array([[0.75, 0.25]])}
- to_log = _add_entropy_to_log(train_batch, to_log)
- assert_close(to_log[f"entropy_buffer_samples_avg"], 0.562335145, 0.001)
- assert_close(to_log[f"entropy_buffer_samples_single"], 0.562335145, 0.001)
+ train_batch = {"action_dist_inputs": torch.tensor([[0.75, 0.25]])}
+ to_log = add_entropy_to_log(train_batch, to_log)
+ assert_are_close(to_log[f"entropy_buffer_samples_avg"], 0.562335145, 0.001)
+ assert_are_close(
+ to_log[f"entropy_buffer_samples_single"], 0.562335145, 0.001
+ )
to_log = {}
- train_batch = {"action_dist_inputs": np.array([[0.62, 0.12, 0.13, 0.13]])}
- to_log = _add_entropy_to_log(train_batch, to_log)
- assert_close(to_log[f"entropy_buffer_samples_avg"], 1.081271236, 0.001)
- assert_close(to_log[f"entropy_buffer_samples_single"], 1.081271236, 0.001)
+ train_batch = {
+ "action_dist_inputs": torch.tensor([[0.62, 0.12, 0.13, 0.13]])
+ }
+ to_log = add_entropy_to_log(train_batch, to_log)
+ assert_are_close(to_log[f"entropy_buffer_samples_avg"], 1.081271236, 0.001)
+ assert_are_close(
+ to_log[f"entropy_buffer_samples_single"], 1.081271236, 0.001
+ )
to_log = {}
train_batch = {
- "action_dist_inputs": np.array(
+ "action_dist_inputs": torch.tensor(
[
[0.62, 0.12, 0.13, 0.13],
[0.75, 0.25, 0.0, 0.0],
@@ -32,13 +39,13 @@ def test__add_entropy_to_log():
]
)
}
- to_log = _add_entropy_to_log(train_batch, to_log)
- assert_close(to_log[f"entropy_buffer_samples_avg"], 0.547868794, 0.001)
- assert_close(to_log[f"entropy_buffer_samples_single"], 0.00, 0.001)
+ to_log = add_entropy_to_log(train_batch, to_log)
+ assert_are_close(to_log[f"entropy_buffer_samples_avg"], 0.547868794, 0.001)
+ assert_are_close(to_log[f"entropy_buffer_samples_single"], 0.00, 0.001)
return to_log
-def assert_close(a, b, threshold):
+def assert_are_close(a, b, threshold):
abs_diff = np.abs(a - b)
assert abs_diff < threshold
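For reference (not part of this patch): the expected constants in test__add_entropy_to_log are consistent with add_entropy_to_log treating each row of "action_dist_inputs" as a probability vector and logging the per-row Shannon entropy in nats plus their mean, with the "single" entry taken from the last row. A small standalone check, using a hypothetical entropy_nats helper only to reproduce the constants:

import torch

def entropy_nats(probs):
    # Shannon entropy of a probability vector, in nats; 0 * log(0) counts as 0.
    p = torch.as_tensor(probs)
    log_p = torch.where(p > 0, p.log(), torch.zeros_like(p))
    return float(-(p * log_p).sum())

assert abs(entropy_nats([0.75, 0.25]) - 0.562335145) < 0.001
assert abs(entropy_nats([0.62, 0.12, 0.13, 0.13]) - 1.081271236) < 0.001
rows = [[0.62, 0.12, 0.13, 0.13], [0.75, 0.25, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
assert abs(sum(entropy_nats(r) for r in rows) / len(rows) - 0.547868794) < 0.001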
diff --git a/tests/marltoolbox/utils/test_rollout.py b/tests/marltoolbox/utils/test_rollout.py
index 29a58ac..bf1c7f9 100644
--- a/tests/marltoolbox/utils/test_rollout.py
+++ b/tests/marltoolbox/utils/test_rollout.py
@@ -1,17 +1,17 @@
import copy
import os
import tempfile
-import time
import numpy as np
+import time
from ray.rllib.agents.pg import PGTrainer, PGTorchPolicy
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
-from marltoolbox.examples.rllib_api.pg_ipd import get_rllib_config
from marltoolbox.envs.matrix_sequential_social_dilemma import (
IteratedPrisonersDilemma,
)
+from marltoolbox.examples.rllib_api.pg_ipd import get_rllib_config
from marltoolbox.utils import log, miscellaneous
from marltoolbox.utils import rollout
@@ -19,7 +19,89 @@
EPI_LENGTH = 33
-class FakeEnvWtCstReward(IteratedPrisonersDilemma):
+def test_rollout_actions_played_equal_actions_specified():
+ policy_agent_mapping = lambda policy_id: policy_id
+ assert_actions_played_equal_actions_specified(
+ policy_agent_mapping,
+ rollout_length=20,
+ num_episodes=1,
+ actions_list=[0, 1] * 100,
+ )
+ assert_actions_played_equal_actions_specified(
+ policy_agent_mapping,
+ rollout_length=40,
+ num_episodes=1,
+ actions_list=[1, 1] * 100,
+ )
+ assert_actions_played_equal_actions_specified(
+ policy_agent_mapping,
+ rollout_length=77,
+ num_episodes=2,
+ actions_list=[0, 0] * 100,
+ )
+ assert_actions_played_equal_actions_specified(
+ policy_agent_mapping,
+ rollout_length=77,
+ num_episodes=3,
+ actions_list=[0, 1] * 100,
+ )
+ assert_actions_played_equal_actions_specified(
+ policy_agent_mapping,
+ rollout_length=6,
+ num_episodes=3,
+ actions_list=[1, 0] * 100,
+ )
+
+
+def assert_actions_played_equal_actions_specified(
+ policy_agent_mapping, rollout_length, num_episodes, actions_list
+):
+ rollout_results, worker = _when_perform_rollouts_wt_given_actions(
+ actions_list, rollout_length, policy_agent_mapping, num_episodes
+ )
+
+ _assert_length_of_rollout(rollout_results, num_episodes, rollout_length)
+
+ n_steps_in_last_epi, steps_in_last_epi = _compute_n_steps_in_last_epi(
+ rollout_results, rollout_length, num_episodes
+ )
+
+ all_steps = _unroll_all_steps(rollout_results)
+
+ # Verify that the actions played are the actions we forced to play
+ _for_each_player_exec_fn(
+ worker,
+ _assert_played_the_actions_specified,
+ all_steps,
+ rollout_length,
+ num_episodes,
+ actions_list,
+ )
+ _for_each_player_exec_fn(
+ worker,
+ _assert_played_the_actions_specified_during_last_epi_only,
+ all_steps,
+ n_steps_in_last_epi,
+ steps_in_last_epi,
+ actions_list,
+ )
+
+
+def _when_perform_rollouts_wt_given_actions(
+ actions_list, rollout_length, policy_agent_mapping, num_episodes
+):
+ worker = _init_worker(actions_list=actions_list)
+ rollout_results = rollout.internal_rollout(
+ worker,
+ num_steps=rollout_length,
+ policy_agent_mapping=policy_agent_mapping,
+ reset_env_before=True,
+ num_episodes=num_episodes,
+ )
+ return rollout_results, worker
+
+
+class _FakeEnvWtCstReward(IteratedPrisonersDilemma):
def step(self, actions: dict):
observations, rewards, epi_is_done, info = super().step(actions)
@@ -29,33 +111,38 @@ def step(self, actions: dict):
return observations, rewards, epi_is_done, info
-def make_FakePolicyWtDefinedActions(list_actions_to_play):
+def _make_fake_policy_wt_defined_actions(list_actions_to_play):
class FakePolicyWtDefinedActions(PGTorchPolicy):
- def compute_actions(self, *args, **kwargs):
+ def _compute_action_helper(self, *args, **kwargs):
action = list_actions_to_play.pop(0)
return np.array([action]), [], {}
+ def _initialize_loss_from_dummy_batch(
+ self,
+ auto_remove_unneeded_view_reqs: bool = True,
+ stats_fn=None,
+ ) -> None:
+ pass
+
return FakePolicyWtDefinedActions
-def init_worker(actions_list=None):
+def _init_worker(actions_list=None):
train_n_replicates = 1
debug = True
- stop_iters = 200
- tf = False
seeds = miscellaneous.get_random_seeds(train_n_replicates)
exp_name, _ = log.log_in_current_day_dir("testing")
- rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf)
- rllib_config["env"] = FakeEnvWtCstReward
+ rllib_config, stop_config = get_rllib_config(seeds, debug)
+ rllib_config["env"] = _FakeEnvWtCstReward
rllib_config["env_config"]["max_steps"] = EPI_LENGTH
rllib_config["seed"] = int(time.time())
if actions_list is not None:
- for policy_id in FakeEnvWtCstReward({}).players_ids:
+ for policy_id in _FakeEnvWtCstReward({}).players_ids:
policy_to_modify = list(
rllib_config["multiagent"]["policies"][policy_id]
)
- policy_to_modify[0] = make_FakePolicyWtDefinedActions(
+ policy_to_modify[0] = _make_fake_policy_wt_defined_actions(
copy.deepcopy(actions_list)
)
rllib_config["multiagent"]["policies"][
@@ -93,105 +180,133 @@ def default_logger_creator(config):
return default_logger_creator
-def test_rollout_constant_reward():
- policy_agent_mapping = lambda policy_id: policy_id
+def _for_each_player_exec_fn(worker, fn, *arg, **kwargs):
+ for policy_id in worker.env.players_ids:
+ fn(policy_id, *arg, **kwargs)
- def assert_(rollout_length, num_episodes):
- worker = init_worker()
- rollout_results = rollout.internal_rollout(
- worker,
- num_steps=rollout_length,
- policy_agent_mapping=policy_agent_mapping,
- reset_env_before=True,
- num_episodes=num_episodes,
- )
- assert (
- rollout_results._num_episodes == num_episodes
- or rollout_results._total_steps == rollout_length
- )
- steps_in_last_epi = rollout_results._current_rollout
- if rollout_results._total_steps == rollout_length:
- n_steps_in_last_epi = rollout_results._total_steps % EPI_LENGTH
- elif rollout_results._num_episodes == num_episodes:
- n_steps_in_last_epi = EPI_LENGTH
-
- # Verify rewards
- for policy_id in worker.env.players_ids:
- rewards = [step[3][policy_id] for step in steps_in_last_epi]
- assert sum(rewards) == n_steps_in_last_epi * CONSTANT_REWARD
- assert len(rewards) == n_steps_in_last_epi
- all_steps = []
- for epi_rollout in rollout_results._rollouts:
- all_steps.extend(epi_rollout)
- for policy_id in worker.env.players_ids:
- rewards = [step[3][policy_id] for step in all_steps]
- assert (
- sum(rewards)
- == min(rollout_length, num_episodes * EPI_LENGTH)
- * CONSTANT_REWARD
- )
- assert len(rewards) == min(
- rollout_length, num_episodes * EPI_LENGTH
- )
+def _assert_played_the_actions_specified(
+ policy_id, all_steps, rollout_length, num_episodes, actions_list
+):
+ actions_played = [step[1][policy_id] for step in all_steps]
+ assert len(actions_played) == min(
+ rollout_length, num_episodes * EPI_LENGTH
+ )
+ for action_required, action_played in zip(
+ actions_list[: len(all_steps)], actions_played
+ ):
+ assert action_required == action_played
- assert_(rollout_length=20, num_episodes=1)
- assert_(rollout_length=40, num_episodes=1)
- assert_(rollout_length=77, num_episodes=2)
- assert_(rollout_length=77, num_episodes=3)
- assert_(rollout_length=6, num_episodes=3)
+def _assert_played_the_actions_specified_during_last_epi_only(
+ policy_id, all_steps, n_steps_in_last_epi, steps_in_last_epi, actions_list
+):
+ actions_played = [step[1][policy_id] for step in steps_in_last_epi]
+ assert len(actions_played) == n_steps_in_last_epi
+ actions_required_during_last_epi = actions_list[: len(all_steps)][
+ -n_steps_in_last_epi:
+ ]
+ for action_required, action_played in zip(
+ actions_required_during_last_epi, actions_played
+ ):
+ assert action_required == action_played
-def test_rollout_specified_actions():
+
+def _assert_length_of_rollout(rollout_results, num_episodes, rollout_length):
+ assert (
+ rollout_results._num_episodes == num_episodes
+ or rollout_results._total_steps == rollout_length
+ )
+
+
+def _compute_n_steps_in_last_epi(
+ rollout_results, rollout_length, num_episodes
+):
+ steps_in_last_epi = rollout_results._current_rollout
+ if rollout_results._total_steps == rollout_length:
+ n_steps_in_last_epi = rollout_results._total_steps % EPI_LENGTH
+ elif rollout_results._num_episodes == num_episodes:
+ n_steps_in_last_epi = EPI_LENGTH
+
+ assert n_steps_in_last_epi == len(
+ steps_in_last_epi
+ ), f"{n_steps_in_last_epi} == {len(steps_in_last_epi)}"
+
+ return n_steps_in_last_epi, steps_in_last_epi
+
+
+def _unroll_all_steps(rollout_results):
+ all_steps = []
+ for epi_rollout in rollout_results._rollouts:
+ all_steps.extend(epi_rollout)
+ return all_steps
+
+
+def test_rollout_rewards_received_equal_constant_reward():
policy_agent_mapping = lambda policy_id: policy_id
+ assert_rewards_received_are_rewards_specified(
+ policy_agent_mapping, rollout_length=20, num_episodes=1
+ )
+ assert_rewards_received_are_rewards_specified(
+ policy_agent_mapping, rollout_length=40, num_episodes=1
+ )
+ assert_rewards_received_are_rewards_specified(
+ policy_agent_mapping, rollout_length=77, num_episodes=2
+ )
+ assert_rewards_received_are_rewards_specified(
+ policy_agent_mapping, rollout_length=77, num_episodes=3
+ )
+ assert_rewards_received_are_rewards_specified(
+ policy_agent_mapping, rollout_length=6, num_episodes=3
+ )
- def assert_(rollout_length, num_episodes, actions_list):
- worker = init_worker(actions_list=actions_list)
- rollout_results = rollout.internal_rollout(
- worker,
- num_steps=rollout_length,
- policy_agent_mapping=policy_agent_mapping,
- reset_env_before=True,
- num_episodes=num_episodes,
- )
- assert (
- rollout_results._num_episodes == num_episodes
- or rollout_results._total_steps == rollout_length
- )
- steps_in_last_epi = rollout_results._current_rollout
- if rollout_results._total_steps == rollout_length:
- n_steps_in_last_epi = rollout_results._total_steps % EPI_LENGTH
- elif rollout_results._num_episodes == num_episodes:
- n_steps_in_last_epi = EPI_LENGTH
-
- # Verify actions
- all_steps = []
- for epi_rollout in rollout_results._rollouts:
- all_steps.extend(epi_rollout)
- for policy_id in worker.env.players_ids:
- actions_played = [step[1][policy_id] for step in all_steps]
- assert len(actions_played) == min(
- rollout_length, num_episodes * EPI_LENGTH
- )
- print(actions_list[1 : 1 + len(all_steps)], actions_played)
- for action_required, action_played in zip(
- actions_list[: len(all_steps)], actions_played
- ):
- assert action_required == action_played
- for policy_id in worker.env.players_ids:
- actions_played = [step[1][policy_id] for step in steps_in_last_epi]
- assert len(actions_played) == n_steps_in_last_epi
- actions_required_during_last_epi = actions_list[: len(all_steps)][
- -n_steps_in_last_epi:
- ]
- for action_required, action_played in zip(
- actions_required_during_last_epi, actions_played
- ):
- assert action_required == action_played
-
- assert_(rollout_length=20, num_episodes=1, actions_list=[0, 1] * 100)
- assert_(rollout_length=40, num_episodes=1, actions_list=[1, 1] * 100)
- assert_(rollout_length=77, num_episodes=2, actions_list=[0, 0] * 100)
- assert_(rollout_length=77, num_episodes=3, actions_list=[0, 1] * 100)
- assert_(rollout_length=6, num_episodes=3, actions_list=[1, 0] * 100)
+def assert_rewards_received_are_rewards_specified(
+ policy_agent_mapping, rollout_length, num_episodes
+):
+ rollout_results, worker = _when_perform_rollouts_wt_given_actions(
+ None, rollout_length, policy_agent_mapping, num_episodes
+ )
+
+ _assert_length_of_rollout(rollout_results, num_episodes, rollout_length)
+
+ n_steps_in_last_epi, steps_in_last_epi = _compute_n_steps_in_last_epi(
+ rollout_results, rollout_length, num_episodes
+ )
+
+ all_steps = _unroll_all_steps(rollout_results)
+
+ # Verify that the rewards received are the ones we defined
+ _for_each_player_exec_fn(
+ worker,
+ _assert_rewards_in_last_epi_are_as_specified,
+ steps_in_last_epi,
+ n_steps_in_last_epi,
+ )
+
+ _for_each_player_exec_fn(
+ worker,
+ _assert_rewards_are_as_defined,
+ all_steps,
+ rollout_length,
+ num_episodes,
+ )
+
+
+def _assert_rewards_in_last_epi_are_as_specified(
+ policy_id, steps_in_last_epi, n_steps_in_last_epi
+):
+ rewards = [step[3][policy_id] for step in steps_in_last_epi]
+ assert sum(rewards) == n_steps_in_last_epi * CONSTANT_REWARD
+ assert len(rewards) == n_steps_in_last_epi
+
+
+def _assert_rewards_are_as_defined(
+ policy_id, all_steps, rollout_length, num_episodes
+):
+ rewards = [step[3][policy_id] for step in all_steps]
+ assert (
+ sum(rewards)
+ == min(rollout_length, num_episodes * EPI_LENGTH) * CONSTANT_REWARD
+ )
+ assert len(rewards) == min(rollout_length, num_episodes * EPI_LENGTH)
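A usage sketch (not applied by this diff) showing how the helpers introduced above compose into a new scenario; check_scripted_rollout is a hypothetical wrapper that assumes it sits inside tests/marltoolbox/utils/test_rollout.py, so that _init_worker and the rollout import are in scope, and rollout_length=12 is an arbitrary example value:

def check_scripted_rollout(rollout_length, num_episodes, actions_list):
    # Given: a worker whose fake policies replay actions_list verbatim.
    worker = _init_worker(actions_list=actions_list)
    # When: rolling out until either the step or the episode bound is hit.
    results = rollout.internal_rollout(
        worker,
        num_steps=rollout_length,
        policy_agent_mapping=lambda policy_id: policy_id,
        reset_env_before=True,
        num_episodes=num_episodes,
    )
    # Then: every player replayed the scripted actions, in order.
    all_steps = [step for epi_rollout in results._rollouts for step in epi_rollout]
    for policy_id in worker.env.players_ids:
        played = [step[1][policy_id] for step in all_steps]
        for required, actually_played in zip(actions_list[: len(all_steps)], played):
            assert required == actually_played

check_scripted_rollout(rollout_length=12, num_episodes=1, actions_list=[0, 1] * 100)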
diff --git a/tests/marltoolbox/utils/test_same_cross_perf.py b/tests/marltoolbox/utils/test_same_cross_perf.py
index 4997e8e..bfb8a29 100644
--- a/tests/marltoolbox/utils/test_same_cross_perf.py
+++ b/tests/marltoolbox/utils/test_same_cross_perf.py
@@ -5,8 +5,7 @@
from ray.rllib.agents.pg import PGTrainer
from marltoolbox.examples.rllib_api.pg_ipd import get_rllib_config
-from marltoolbox.utils import log, miscellaneous, restore
-from marltoolbox.utils import self_and_cross_perf
+from marltoolbox.utils import log, miscellaneous, restore, cross_play
from marltoolbox.utils.miscellaneous import get_random_seeds
@@ -15,7 +14,7 @@ def _init_evaluator():
rllib_config, stop_config = get_rllib_config(seeds=get_random_seeds(1))
- evaluator = self_and_cross_perf.SelfAndCrossPlayEvaluator(
+ evaluator = cross_play.evaluator.SelfAndCrossPlayEvaluator(
exp_name=exp_name,
)
evaluator.define_the_experiment_to_run(
@@ -34,9 +33,10 @@ def _train_pg_in_ipd(train_n_replicates):
seeds = miscellaneous.get_random_seeds(train_n_replicates)
exp_name, _ = log.log_in_current_day_dir("testing")
+ ray.shutdown()
ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
- rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf)
+ rllib_config, stop_config = get_rllib_config(seeds, debug)
tune_analysis = tune.run(
PGTrainer,
config=rllib_config,