1 change: 1 addition & 0 deletions .gitignore
@@ -50,3 +50,4 @@ marltoolbox/examples/Tutorial_*.py
 marltoolbox/experiments/Tutorial_*.py
 api_key_wandb
 wandb/
+requirements.txt
18 changes: 18 additions & 0 deletions marltoolbox/algos/amTFT/__init__.py
@@ -19,3 +19,21 @@
     AmTFTRolloutsTorchPolicy,
 )
 from marltoolbox.algos.amTFT.train_helper import train_amtft
+
+
+__all__ = [
+    "train_amtft",
+    "AmTFTRolloutsTorchPolicy",
+    "amTFTQValuesTorchPolicy",
+    "Level1amTFTExploiterTorchPolicy",
+    "observation_fn",
+    "AmTFTCallbacks",
+    "WORKING_STATES",
+    "WORKING_STATES_IN_EVALUATION",
+    "AmTFTReferenceClass",
+    "DEFAULT_NESTED_POLICY_COOP",
+    "DEFAULT_NESTED_POLICY_SELFISH",
+    "PLOT_ASSEMBLAGE_TAGS",
+    "PLOT_KEYS",
+    "DEFAULT_CONFIG",
+]
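For reference, defining `__all__` pins down the package's public API: a star-import now binds exactly the names listed above and nothing else. A minimal sketch of the effect (the assertions are illustrative):

```python
# Minimal sketch: with __all__ defined in marltoolbox/algos/amTFT/__init__.py,
# `import *` exposes exactly the exported names.
from marltoolbox.algos.amTFT import *  # noqa: F401,F403

assert "train_amtft" in dir()  # exported training helper
assert "DEFAULT_CONFIG" in dir()  # exported config template
```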
60 changes: 35 additions & 25 deletions marltoolbox/algos/amTFT/base.py
@@ -4,20 +4,23 @@
 from ray.rllib.utils import merge_dicts
 
 from marltoolbox.algos import hierarchical, augmented_dqn
-from marltoolbox.utils import \
-    postprocessing, miscellaneous
+from marltoolbox.utils import postprocessing, miscellaneous
 
 logger = logging.getLogger(__name__)
 
 APPROXIMATION_METHOD_Q_VALUE = "amTFT_use_Q_net"
 APPROXIMATION_METHOD_ROLLOUTS = "amTFT_use_rollout"
-APPROXIMATION_METHODS = (APPROXIMATION_METHOD_Q_VALUE,
-                         APPROXIMATION_METHOD_ROLLOUTS)
-WORKING_STATES = ("train_coop",
-                  "train_selfish",
-                  "eval_amtft",
-                  "eval_naive_selfish",
-                  "eval_naive_coop")
+APPROXIMATION_METHODS = (
+    APPROXIMATION_METHOD_Q_VALUE,
+    APPROXIMATION_METHOD_ROLLOUTS,
+)
+WORKING_STATES = (
+    "train_coop",
+    "train_selfish",
+    "eval_amtft",
+    "eval_naive_selfish",
+    "eval_naive_coop",
+)
 WORKING_STATES_IN_EVALUATION = WORKING_STATES[2:]
 
 OWN_COOP_POLICY_IDX = 0
@@ -29,8 +32,9 @@
 DEFAULT_NESTED_POLICY_COOP = DEFAULT_NESTED_POLICY_SELFISH.with_updates(
     postprocess_fn=miscellaneous.merge_policy_postprocessing_fn(
         postprocessing.welfares_postprocessing_fn(
-            add_utilitarian_welfare=True, ),
-        postprocess_nstep_and_prio
+            add_utilitarian_welfare=True,
+        ),
+        postprocess_nstep_and_prio,
     )
 )
 
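The hunk above composes two trajectory postprocessors: the welfare annotation runs first, then DQN's `postprocess_nstep_and_prio`. A hypothetical sketch of what merging policy postprocessing functions amounts to (not the actual `miscellaneous.merge_policy_postprocessing_fn` implementation; the signature follows RLlib's `postprocess_fn` convention):

```python
# Hypothetical sketch: chain several RLlib postprocess_fn callables so each
# one transforms the sampled trajectory batch in turn.
def merge_policy_postprocessing_fns(*postprocess_fns):
    def merged(policy, sample_batch, other_agent_batches=None, episode=None):
        for fn in postprocess_fns:
            sample_batch = fn(
                policy, sample_batch, other_agent_batches, episode
            )
        return sample_batch

    return merged
```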
@@ -44,39 +48,44 @@
         "rollout_length": 40,
         "n_rollout_replicas": 20,
         "last_k": 1,
+        "punish_instead_of_selfish": False,
         # TODO use log level of RLLib instead of mine
         "verbose": 1,
         "auto_load_checkpoint": True,
-
         # To configure
         "own_policy_id": None,
         "opp_policy_id": None,
         "callbacks": None,
         # One from marltoolbox.utils.postprocessing.WELFARES
         "welfare_key": postprocessing.WELFARE_UTILITARIAN,
-
-        'nested_policies': [
+        "nested_policies": [
             # Here the trainer needs to be a DQNTrainer
             # to provide the config for the 3 DQNTorchPolicy
-            {"Policy_class": DEFAULT_NESTED_POLICY_COOP,
-             "config_update": {}},
-            {"Policy_class": DEFAULT_NESTED_POLICY_SELFISH,
-             "config_update": {}},
-            {"Policy_class": DEFAULT_NESTED_POLICY_COOP,
-             "config_update": {}},
-            {"Policy_class": DEFAULT_NESTED_POLICY_SELFISH,
-             "config_update": {}},
+            {"Policy_class": DEFAULT_NESTED_POLICY_COOP, "config_update": {}},
+            {
+                "Policy_class": DEFAULT_NESTED_POLICY_SELFISH,
+                "config_update": {},
+            },
+            {"Policy_class": DEFAULT_NESTED_POLICY_COOP, "config_update": {}},
+            {
+                "Policy_class": DEFAULT_NESTED_POLICY_SELFISH,
+                "config_update": {},
+            },
         ],
-        "optimizer": {"sgd_momentum": 0.0, },
-    }
+        "optimizer": {
+            "sgd_momentum": 0.0,
+        },
+        "batch_mode": "complete_episodes",
+    },
 )
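Because `DEFAULT_CONFIG` is assembled with `merge_dicts`, downstream experiments can override the new keys the same way. A minimal sketch, assuming the exports above; the policy IDs "player_row"/"player_col" are hypothetical placeholders:

```python
from ray.rllib.utils import merge_dicts

from marltoolbox.algos import amTFT

# Hypothetical override of the defaults shown in this diff;
# "player_row"/"player_col" are illustrative policy IDs.
my_amtft_config = merge_dicts(
    amTFT.DEFAULT_CONFIG,
    {
        "own_policy_id": "player_row",
        "opp_policy_id": "player_col",
        "punish_instead_of_selfish": True,  # flag introduced in this PR
    },
)
```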

 PLOT_KEYS = [
     "punish",
     "debit",
     "debit_threshold",
     "summed_debit",
-    "summed_n_steps_to_punish"
+    "summed_n_steps_to_punish",
+    "reset_rnn_state",
 ]
 
 PLOT_ASSEMBLAGE_TAGS = [
@@ -85,6 +94,7 @@
     ("debit_threshold",),
     ("summed_debit",),
     ("summed_n_steps_to_punish",),
+    ("reset_rnn_state",),
 ]


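`PLOT_KEYS` lists metric-name patterns and `PLOT_ASSEMBLAGE_TAGS` groups them into per-figure tuples; the new `reset_rnn_state` entry is added to both. A hypothetical illustration of how such lists can drive metric selection and grouping (the actual plotting helpers in marltoolbox may differ):

```python
from marltoolbox.algos.amTFT import PLOT_ASSEMBLAGE_TAGS, PLOT_KEYS

# Illustrative logged metrics; names and values are made up.
logged = {
    "player_row/summed_debit": [0.1, 0.4],
    "player_row/reset_rnn_state": [0.0, 1.0],
    "lr": [1e-3],
}

# Keep only metrics whose name matches one of the plot keys.
selected = {
    name: values
    for name, values in logged.items()
    if any(key in name for key in PLOT_KEYS)
}

# Build one group of curves per assemblage tuple, i.e. per figure.
figures = [
    {name: values for name, values in selected.items()
     if any(tag in name for tag in tags)}
    for tags in PLOT_ASSEMBLAGE_TAGS
]
```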